# app.py — MITRP Policy Assistant
# Source: alanmcmillan, "Update app.py", commit 15d8c4e (verified)
import os
import pickle
import numpy as np
import gradio as gr
from dataclasses import dataclass, field
from sentence_transformers import SentenceTransformer
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider
from typing import List, Dict
# --- CONFIGURATION ---
CACHE_PATH = "vector_store_cache.pkl"  # pre-built pickle produced by build_index.py
MODEL_NAME = "gemini-2.5-flash-lite"   # Gemini model used by the pydantic_ai agent
# SECURITY: a shared password committed to the repo is visible to anyone with
# read access. Allow overriding it via the environment; the original literal
# remains the fallback so existing deployments keep working unchanged.
ACCESS_PASSWORD = os.getenv("ACCESS_PASSWORD", "secret-mitrp-password")
# ==========================================
# PART 1: BACKEND LOGIC (RAG & AGENT)
# ==========================================
@dataclass
class VectorStore:
    """In-memory store of text chunks paired with their embedding vectors."""

    # each chunk dict: {text, page_start, page_end, chunk_id}
    chunks: List[Dict] = field(default_factory=list)
    embeddings: np.ndarray = field(default_factory=lambda: np.array([]))

    def search(self, query: str, model: SentenceTransformer, top_k: int = 5) -> List[Dict]:
        """Return up to `top_k` chunks ranked by cosine similarity to `query`.

        Each hit is a dict with keys `text`, `score` (float similarity) and
        `pages` (page range string). An empty store yields an empty list.
        """
        if not self.chunks:
            return []
        # Embed the query and normalize it; the epsilon guards against a
        # zero-norm vector producing a division by zero.
        q_vec = model.encode([query])[0]
        q_unit = q_vec / (np.linalg.norm(q_vec) + 1e-9)
        # Normalize every stored embedding row-wise, then a single matrix-vector
        # product gives the cosine similarity of the query against all chunks.
        row_norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-9
        scores = (self.embeddings / row_norms) @ q_unit
        # argsort ascending, keep the last top_k, reverse to best-first.
        ranked = np.argsort(scores)[-top_k:][::-1]
        hits = []
        for idx in ranked:
            chunk = self.chunks[idx]
            hits.append(
                {
                    "text": chunk["text"],
                    "score": float(scores[idx]),
                    "pages": f"{chunk.get('page_start', '?')}–{chunk.get('page_end', '?')}",
                }
            )
        return hits
def load_vector_store() -> VectorStore:
    """Load the pre-built vector index from the pickle cache.

    Raises:
        FileNotFoundError: when the cache file does not exist; the message
            explains how to build and commit it.
    """
    if not os.path.exists(CACHE_PATH):
        raise FileNotFoundError(
            f"Cache file '{CACHE_PATH}' not found. "
            "Run `uv run build_index.py` to generate it, then commit it to your repo."
        )
    print(f"⏳ Loading vector store from {CACHE_PATH}...")
    # SECURITY NOTE: pickle.load can execute arbitrary code from the file;
    # only load caches built and committed by trusted project members.
    with open(CACHE_PATH, "rb") as cache_file:
        payload = pickle.load(cache_file)
    loaded_chunks = payload["chunks"]
    loaded_embeddings = payload["embeddings"]
    print(f"βœ… Loaded {len(loaded_chunks)} chunks.")
    return VectorStore(chunks=loaded_chunks, embeddings=loaded_embeddings)
# Initialize embedding model and vector store at startup.
# These are module-level singletons: loaded once at import time and then
# reused by every request (search_policy uses embed_model, chat_logic uses
# global_vector_store).
print("⏳ Loading embedding model...")
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
global_vector_store = load_vector_store()
# Initialize Pydantic AI Agent.
# The agent is only constructed when GEMINI_API_KEY is present in the
# environment; otherwise it stays None and chat_logic returns a configuration
# error to the user instead of crashing.
api_key = os.getenv("GEMINI_API_KEY")
agent = None
if api_key:
    provider = GoogleProvider(api_key=api_key)
    model = GoogleModel(MODEL_NAME, provider=provider)
    agent = Agent(
        model,
        # deps_type declares the dependency object passed per-run via
        # `agent.run(..., deps=...)` and surfaced to tools as ctx.deps.
        deps_type=VectorStore,
        system_prompt=(
            "You are an expert on MITRP Policies. "
            "Always call `search_policy` to retrieve relevant excerpts before answering. "
            "Cite the page numbers provided in each excerpt. "
            "If the retrieved text does not contain the answer, say so explicitly."
        ),
    )

    @agent.tool
    def search_policy(ctx: RunContext[VectorStore], query: str) -> str:
        """Search the MITRP policy document for relevant excerpts."""
        # NOTE: the docstring above doubles as the tool description shown to
        # the LLM — do not edit it casually.
        # ctx.deps is the VectorStore supplied at run time; embed_model is the
        # module-level sentence-transformer loaded at startup.
        results = ctx.deps.search(query, embed_model, top_k=5)
        if not results:
            return "No relevant policy sections found."
        # Format each hit with its page range and similarity score so the
        # model can cite pages as instructed by the system prompt.
        return "\n\n".join(
            f"--- Excerpt (p. {r['pages']}, relevance {r['score']:.2f}) ---\n{r['text']}"
            for r in results
        )
else:
    print("⚠️ GEMINI_API_KEY not set β€” agent will not function.")
# ==========================================
# PART 2: FRONTEND LOGIC (UI & AUTH)
# ==========================================
async def chat_logic(message, history):
    """Gradio chat handler: run the agent on the user's message.

    `history` is supplied by gr.ChatInterface and intentionally unused; the
    agent is stateless per call. Any failure is reported as an error string
    rather than raised, so the chat UI never breaks.
    """
    if not agent:
        return "⚠️ Error: GEMINI_API_KEY is not configured."
    try:
        run_result = await agent.run(message, deps=global_vector_store)
        # Prefer `.output`, fall back to `.data`, finally the repr — covers
        # result-object differences across pydantic_ai versions.
        fallback = getattr(run_result, "data", str(run_result))
        return getattr(run_result, "output", fallback)
    except Exception as exc:
        return f"Error: {str(exc)}"
def login_logic(password):
    """Validate the access password and toggle the login/chat screens.

    Returns a 3-tuple of Gradio updates for (login column, chat column,
    error-message markdown): on success hide login / show chat / clear the
    error; on failure keep login visible and show an error message.
    """
    import hmac  # local import so this fix needs no change to file-level imports

    # SECURITY: compare in constant time — a plain `==` on a secret leaks
    # timing information. Encode to bytes since compare_digest requires
    # ASCII-only str or bytes.
    if hmac.compare_digest(password.encode("utf-8"), ACCESS_PASSWORD.encode("utf-8")):
        return gr.update(visible=False), gr.update(visible=True), ""
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        "<p style='color:red'>❌ Incorrect Password</p>",
    )
# --- GRADIO BLOCKS LAYOUT ---
custom_css = "footer {visibility: hidden}"
# BUG FIX: `theme` and `css` are gr.Blocks() constructor arguments, not
# Blocks.launch() arguments — passing them to launch() raises a TypeError
# (and custom_css was never applied). They now configure the Blocks itself.
with gr.Blocks(title="MITRP Policy Assistant", theme="soft", css=custom_css) as app:
    # --- SCREEN 1: LOGIN ---
    with gr.Column(visible=True) as login_col:
        gr.Markdown("## πŸ”’ MITRP Policy Bot\nPlease enter the access password to continue.")
        with gr.Row():
            pass_input = gr.Textbox(
                label="Password",
                type="password",
                placeholder="Enter password...",
                show_label=False,
                scale=4,
            )
            login_btn = gr.Button("Login", variant="primary", scale=1)
        error_msg = gr.Markdown("")

    # --- SCREEN 2: CHAT ---
    with gr.Column(visible=False) as chat_col:
        gr.Markdown("## πŸ›οΈ MITRP Policy Assistant")
        chat_interface = gr.ChatInterface(
            fn=chat_logic,
            examples=[
                "How many papers should I write per year?",
                "What is the vacation policy?",
                "How do I connect to the GPU machines?",
            ],
        )

    # --- EVENT LISTENERS ---
    # Both the Login button and pressing Enter in the password box submit.
    login_btn.click(
        fn=login_logic,
        inputs=[pass_input],
        outputs=[login_col, chat_col, error_msg],
    )
    pass_input.submit(
        fn=login_logic,
        inputs=[pass_input],
        outputs=[login_col, chat_col, error_msg],
    )

if __name__ == "__main__":
    app.launch()