# NOTE: the "Spaces: Sleeping" lines here were Hugging Face page chrome
# captured by the scraper, not part of the program.
import os
import pickle
from dataclasses import dataclass, field
from typing import Dict, List

import gradio as gr
import numpy as np
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider
from sentence_transformers import SentenceTransformer
# --- CONFIGURATION ---
CACHE_PATH = "vector_store_cache.pkl"  # pre-built embedding index (see build_index.py)
MODEL_NAME = "gemini-2.5-flash-lite"   # Gemini model used by the agent
# The password can be overridden via the ACCESS_PASSWORD env var so the secret
# does not have to live in source control; the original literal is kept as the
# backward-compatible default.
ACCESS_PASSWORD = os.getenv("ACCESS_PASSWORD", "secret-mitrp-password")
# ==========================================
# PART 1: BACKEND LOGIC (RAG & AGENT)
# ==========================================
@dataclass
class VectorStore:
    """In-memory cosine-similarity index over document chunks.

    Attributes:
        chunks: list of dicts; each has keys ``text``, ``page_start``,
            ``page_end``, ``chunk_id`` (per the build script).
        embeddings: float array of shape (len(chunks), dim); row i is the
            embedding of ``chunks[i]``.
    """

    # BUG FIX: the original class used dataclasses.field() defaults but was
    # missing the @dataclass decorator, so the keyword construction
    # VectorStore(chunks=..., embeddings=...) at startup raised TypeError.
    chunks: List[Dict] = field(default_factory=list)
    embeddings: np.ndarray = field(default_factory=lambda: np.array([]))

    def search(self, query: str, model: "SentenceTransformer", top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` chunks most similar to ``query``.

        Args:
            query: natural-language search string.
            model: encoder exposing ``encode(list[str]) -> array`` (duck-typed;
                annotation quoted so the class loads without the package).
            top_k: maximum number of results to return.

        Returns:
            List of ``{"text", "score", "pages"}`` dicts, best match first;
            empty list when the store holds no chunks.
        """
        if len(self.chunks) == 0:
            return []
        query_embedding = model.encode([query])[0]
        # Normalize both sides so the dot product equals cosine similarity;
        # the epsilon guards against division by zero for all-zero vectors.
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-9)
        norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-9
        normalized = self.embeddings / norms
        similarities = normalized @ query_norm
        # argsort is ascending: take the last top_k, reversed for best-first.
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [
            {
                "text": self.chunks[i]["text"],
                "score": float(similarities[i]),
                # En dash between start/end page (mojibake repaired).
                "pages": f"{self.chunks[i].get('page_start', '?')}–{self.chunks[i].get('page_end', '?')}",
            }
            for i in top_indices
        ]
def load_vector_store() -> VectorStore:
    """Load the pre-built index from ``CACHE_PATH``.

    Returns:
        A populated ``VectorStore``.

    Raises:
        FileNotFoundError: when the cache file is absent.
        KeyError: when the pickle lacks the 'chunks'/'embeddings' keys.
    """
    if not os.path.exists(CACHE_PATH):
        raise FileNotFoundError(
            f"Cache file '{CACHE_PATH}' not found. "
            "Run `uv run build_index.py` to generate it, then commit it to your repo."
        )
    print(f"⏳ Loading vector store from {CACHE_PATH}...")
    # SECURITY: pickle.load executes arbitrary code from the file. Acceptable
    # here only because the cache is built and committed by the repo owner —
    # never load a cache from user-supplied input.
    with open(CACHE_PATH, "rb") as f:
        data = pickle.load(f)
    chunks = data["chunks"]
    # Coerce to ndarray so search()'s axis-based math works even if the cache
    # stored a plain list.
    embeddings = np.asarray(data["embeddings"])
    print(f"✅ Loaded {len(chunks)} chunks.")
    return VectorStore(chunks=chunks, embeddings=embeddings)
# --- Startup: embedding model, vector store, and agent wiring ---
print("⏳ Loading embedding model...")
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
global_vector_store = load_vector_store()

# Initialize the Pydantic AI agent only when an API key is available; the UI
# degrades to an error message otherwise (see chat_logic).
api_key = os.getenv("GEMINI_API_KEY")
agent = None
if api_key:
    provider = GoogleProvider(api_key=api_key)
    model = GoogleModel(MODEL_NAME, provider=provider)
    agent = Agent(
        model,
        deps_type=VectorStore,
        system_prompt=(
            "You are an expert on MITRP Policies. "
            "Always call `search_policy` to retrieve relevant excerpts before answering. "
            "Cite the page numbers provided in each excerpt. "
            "If the retrieved text does not contain the answer, say so explicitly."
        ),
    )

    # BUG FIX: the tool function was defined but never registered on the
    # agent, so the model could not call it even though the system prompt
    # demands it. @agent.tool registers it and injects RunContext with the
    # VectorStore passed as deps= at run time.
    @agent.tool
    def search_policy(ctx: RunContext[VectorStore], query: str) -> str:
        """Search the MITRP policy document for relevant excerpts."""
        results = ctx.deps.search(query, embed_model, top_k=5)
        if not results:
            return "No relevant policy sections found."
        return "\n\n".join(
            f"--- Excerpt (p. {r['pages']}, relevance {r['score']:.2f}) ---\n{r['text']}"
            for r in results
        )
else:
    print("⚠️ GEMINI_API_KEY not set — agent will not function.")
# ==========================================
# PART 2: FRONTEND LOGIC (UI & AUTH)
# ==========================================
async def chat_logic(message, history):
    """Gradio ChatInterface handler: run the agent on one user message.

    ``history`` is accepted because ChatInterface passes it, but the agent is
    run statelessly per message. Returns the reply text, or an error string so
    failures surface in the chat UI instead of crashing the app.
    """
    if not agent:
        return "⚠️ Error: GEMINI_API_KEY is not configured."
    try:
        result = await agent.run(message, deps=global_vector_store)
        # Pydantic AI renamed `result.data` to `result.output` across
        # versions; support both, falling back to str() as a last resort.
        return getattr(result, "output", getattr(result, "data", str(result)))
    except Exception as e:  # boundary handler: show the failure to the user
        return f"Error: {str(e)}"
def login_logic(password):
    """Check the access password and toggle the login/chat screens.

    Returns a 3-tuple of gr.update values for (login_col, chat_col, error_msg).
    """
    import hmac  # stdlib; local import keeps the file's import block unchanged

    # SECURITY: compare_digest is constant-time, avoiding a timing side
    # channel on the password check; encode both sides so non-ASCII input
    # cannot raise TypeError.
    if hmac.compare_digest(str(password).encode("utf-8"), ACCESS_PASSWORD.encode("utf-8")):
        return gr.update(visible=False), gr.update(visible=True), ""
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        "<p style='color:red'>❌ Incorrect Password</p>",
    )
# --- GRADIO BLOCKS LAYOUT ---
custom_css = "footer {visibility: hidden}"

# BUG FIX: `theme` and `css` are gr.Blocks() constructor arguments, not
# app.launch() arguments — passing them to launch() raises a TypeError and
# custom_css was otherwise dead.
with gr.Blocks(title="MITRP Policy Assistant", theme="soft", css=custom_css) as app:
    # --- SCREEN 1: LOGIN ---
    with gr.Column(visible=True) as login_col:
        gr.Markdown("## 🔐 MITRP Policy Bot\nPlease enter the access password to continue.")
        with gr.Row():
            pass_input = gr.Textbox(
                label="Password",
                type="password",
                placeholder="Enter password...",
                show_label=False,
                scale=4,
            )
            login_btn = gr.Button("Login", variant="primary", scale=1)
        error_msg = gr.Markdown("")

    # --- SCREEN 2: CHAT (hidden until login succeeds) ---
    with gr.Column(visible=False) as chat_col:
        gr.Markdown("## 🏛️ MITRP Policy Assistant")
        chat_interface = gr.ChatInterface(
            fn=chat_logic,
            examples=[
                "How many papers should I write per year?",
                "What is the vacation policy?",
                "How do I connect to the GPU machines?",
            ],
        )

    # --- EVENT LISTENERS: both the button and Enter in the textbox submit ---
    login_btn.click(
        fn=login_logic,
        inputs=[pass_input],
        outputs=[login_col, chat_col, error_msg],
    )
    pass_input.submit(
        fn=login_logic,
        inputs=[pass_input],
        outputs=[login_col, chat_col, error_msg],
    )

if __name__ == "__main__":
    app.launch()