"""MITRP Policy Assistant — RAG-backed chatbot with a password-gated Gradio UI.

Pipeline: a pre-built pickle cache of text chunks + embeddings is loaded at
startup, a Gemini agent (pydantic-ai) retrieves excerpts through a search
tool, and a two-screen Gradio app (login -> chat) serves the front end.
"""

import hmac
import os
import pickle
from dataclasses import dataclass, field
from typing import Dict, List

import gradio as gr
import numpy as np
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider
from sentence_transformers import SentenceTransformer

# --- CONFIGURATION ---
CACHE_PATH = "vector_store_cache.pkl"
MODEL_NAME = "gemini-2.5-flash-lite"
# Overridable via the environment; the literal default keeps existing
# deployments working. NOTE(review): a secret committed in source should be
# rotated once all deployments set the ACCESS_PASSWORD env var.
ACCESS_PASSWORD = os.getenv("ACCESS_PASSWORD", "secret-mitrp-password")

# ==========================================
# PART 1: BACKEND LOGIC (RAG & AGENT)
# ==========================================


@dataclass
class VectorStore:
    """In-memory store of document chunks and their embedding matrix."""

    # Each chunk dict: {text, page_start, page_end, chunk_id}
    chunks: List[Dict] = field(default_factory=list)
    # Row i of `embeddings` is the embedding of chunks[i].
    embeddings: np.ndarray = field(default_factory=lambda: np.array([]))

    def search(
        self, query: str, model: SentenceTransformer, top_k: int = 5
    ) -> List[Dict]:
        """Return up to `top_k` chunks ranked by cosine similarity to `query`.

        Each result dict has keys `text`, `score` (float similarity) and
        `pages` (human-readable page range). Returns [] for an empty store.
        """
        if len(self.chunks) == 0:
            return []
        query_embedding = model.encode([query])[0]
        # Normalize both sides so the dot product equals cosine similarity;
        # the 1e-9 epsilon guards against division by a zero-norm vector.
        query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-9)
        norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-9
        normalized = self.embeddings / norms
        similarities = normalized @ query_norm
        # Indices of the top_k largest similarities, best first.
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [
            {
                "text": self.chunks[i]["text"],
                "score": float(similarities[i]),
                "pages": f"{self.chunks[i].get('page_start', '?')}–{self.chunks[i].get('page_end', '?')}",
            }
            for i in top_indices
        ]


def load_vector_store() -> VectorStore:
    """Load pre-built index from cache. Raises if missing."""
    if not os.path.exists(CACHE_PATH):
        raise FileNotFoundError(
            f"Cache file '{CACHE_PATH}' not found. "
            "Run `uv run build_index.py` to generate it, then commit it to your repo."
        )
    print(f"⏳ Loading vector store from {CACHE_PATH}...")
    # SECURITY: pickle.load executes arbitrary code from the file — only load
    # a cache you built and committed yourself; never a user-supplied file.
    with open(CACHE_PATH, "rb") as f:
        data = pickle.load(f)
    chunks = data["chunks"]
    embeddings = data["embeddings"]
    print(f"✅ Loaded {len(chunks)} chunks.")
    return VectorStore(chunks=chunks, embeddings=embeddings)


# Initialize embedding model and vector store at startup
print("⏳ Loading embedding model...")
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
global_vector_store = load_vector_store()

# Initialize Pydantic AI Agent (only if an API key is configured).
api_key = os.getenv("GEMINI_API_KEY")
agent = None

if api_key:
    provider = GoogleProvider(api_key=api_key)
    model = GoogleModel(MODEL_NAME, provider=provider)
    agent = Agent(
        model,
        deps_type=VectorStore,
        system_prompt=(
            "You are an expert on MITRP Policies. "
            "Always call `search_policy` to retrieve relevant excerpts before answering. "
            "Cite the page numbers provided in each excerpt. "
            "If the retrieved text does not contain the answer, say so explicitly."
        ),
    )

    @agent.tool
    def search_policy(ctx: RunContext[VectorStore], query: str) -> str:
        """Search the MITRP policy document for relevant excerpts."""
        results = ctx.deps.search(query, embed_model, top_k=5)
        if not results:
            return "No relevant policy sections found."
        return "\n\n".join(
            f"--- Excerpt (p. {r['pages']}, relevance {r['score']:.2f}) ---\n{r['text']}"
            for r in results
        )
else:
    print("⚠️ GEMINI_API_KEY not set — agent will not function.")


# ==========================================
# PART 2: FRONTEND LOGIC (UI & AUTH)
# ==========================================


async def chat_logic(message, history):
    """Run one chat turn through the agent; return the reply (or an error)."""
    if not agent:
        return "⚠️ Error: GEMINI_API_KEY is not configured."
    try:
        result = await agent.run(message, deps=global_vector_store)
        # pydantic-ai renamed `.data` to `.output` across versions; support both.
        return getattr(result, "output", getattr(result, "data", str(result)))
    except Exception as e:
        return f"Error: {str(e)}"


def login_logic(password):
    """Toggle login/chat column visibility based on the supplied password.

    Returns Gradio updates for (login_col, chat_col, error_msg).
    """
    # compare_digest is constant-time, so the check does not leak how much of
    # the password prefix matched via response timing.
    if hmac.compare_digest(password or "", ACCESS_PASSWORD):
        return gr.update(visible=False), gr.update(visible=True), ""
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        "❌ Incorrect Password",
    )


# --- GRADIO BLOCKS LAYOUT ---
custom_css = "footer {visibility: hidden}"

# BUG FIX: `theme` and `css` are gr.Blocks() constructor arguments, not
# launch() arguments — the original app.launch(theme=..., css=...) raises
# TypeError. They are now passed here instead.
with gr.Blocks(title="MITRP Policy Assistant", theme="soft", css=custom_css) as app:
    # --- SCREEN 1: LOGIN ---
    with gr.Column(visible=True) as login_col:
        gr.Markdown("## 🔒 MITRP Policy Bot\nPlease enter the access password to continue.")
        with gr.Row():
            pass_input = gr.Textbox(
                label="Password",
                type="password",
                placeholder="Enter password...",
                show_label=False,
                scale=4,
            )
            login_btn = gr.Button("Login", variant="primary", scale=1)
        error_msg = gr.Markdown("")

    # --- SCREEN 2: CHAT ---
    with gr.Column(visible=False) as chat_col:
        gr.Markdown("## 🏛️ MITRP Policy Assistant")
        chat_interface = gr.ChatInterface(
            fn=chat_logic,
            examples=[
                "How many papers should I write per year?",
                "What is the vacation policy?",
                "How do I connect to the GPU machines?",
            ],
        )

    # --- EVENT LISTENERS ---
    # Both the button click and pressing Enter in the textbox attempt login.
    login_btn.click(
        fn=login_logic,
        inputs=[pass_input],
        outputs=[login_col, chat_col, error_msg],
    )
    pass_input.submit(
        fn=login_logic,
        inputs=[pass_input],
        outputs=[login_col, chat_col, error_msg],
    )

if __name__ == "__main__":
    app.launch()