stochastic / agent.py
Sonu Prasad
initial commit
822c114
import logging
from dataclasses import dataclass
from typing import Optional

import google.generativeai as genai
import openai

from config import config
from vector_store import vector_store, SearchResult
# Instruction text given to the LLM on every request: it is sent as the
# "system" message on the OpenRouter path and prepended to the prompt on the
# Gemini path. It restricts answers to the supplied context and asks for
# cited, concise, markdown-formatted replies.
SYSTEM_PROMPT = """You are a research assistant that answers questions based on the provided document context.
Instructions:
1. Answer ONLY based on the provided context
2. If the context doesn't contain enough information, say so
3. Cite sources by mentioning the paper name and section
4. Be concise but thorough
5. Format responses with markdown for readability"""
@dataclass
class QueryResponse:
    """Outcome of a single document query.

    Carries the generated answer text together with the sources it was
    built from.
    """

    # Answer text returned to the caller (LLM output or a fixed fallback message).
    answer: str

    # One dict per cited source; DocumentAgent fills these with "paper" and
    # "section" keys.
    sources: list[dict]

    # Not populated by the code in this module; defaults to None.
    # NOTE(review): presumably identifies an arXiv paper fetched on demand —
    # confirm against the caller that sets it.
    arxiv_fetched: Optional[str] = None
class DocumentAgent:
    """Answer questions from indexed documents using an LLM.

    For each question the agent searches the vector store, assembles the
    matching chunks into a context string, and asks an LLM to answer from
    that context. OpenRouter is tried first when its API key is configured,
    then Gemini; if neither yields an answer a fixed fallback message is
    returned.
    """

    def __init__(self):
        """Initialise an empty history and configure Gemini when a key is set."""
        # Kept for callers of clear_history(); nothing in this class appends to it.
        self.conversation_history: list[dict] = []
        if config.GEMINI_API_KEY:
            genai.configure(api_key=config.GEMINI_API_KEY)

    def query(self, question: str, paper_filter: Optional[str] = None) -> "QueryResponse":
        """Answer *question* from the indexed documents.

        Args:
            question: The user's question; also used as the search query.
            paper_filter: Passed through to ``vector_store.search`` unchanged.

        Returns:
            A ``QueryResponse`` with the generated answer and de-duplicated
            sources. When the search returns nothing, the answer is a fixed
            "no documents indexed" message and ``sources`` is empty.
        """
        results = vector_store.search(question, paper_filter=paper_filter)
        if not results:
            return QueryResponse(
                answer="I don't have any documents indexed yet. Please upload a PDF first.",
                sources=[],
            )

        context = self._build_context(results)
        answer = self._generate_answer(question, context)
        sources = self._format_sources(results)
        return QueryResponse(answer=answer, sources=sources)

    def _build_context(self, results: "list[SearchResult]") -> str:
        """Join the retrieved chunks into one context string for the LLM.

        Each chunk is rendered as a ``[Source: ... | Section: ...]`` header
        followed by its content; chunks are separated by a ``---`` rule.
        A missing section title is shown as "Content", matching the label
        used by ``_format_sources``, instead of the literal ``None``.
        """
        return "\n\n---\n\n".join(
            f"[Source: {r.chunk.paper_name} | "
            f"Section: {r.chunk.section_title or 'Content'}]\n{r.chunk.content}"
            for r in results
        )

    def _generate_answer(self, question: str, context: str) -> str:
        """Ask the configured LLM providers for an answer, in priority order.

        OpenRouter is tried first, then Gemini. A provider is skipped when
        its API key is not configured. A provider that raises, or that
        returns an empty answer, is logged and the next one is tried.

        Returns:
            The first non-empty answer, or a fixed "unable to generate"
            message when no provider produced one.
        """
        logger = logging.getLogger(__name__)
        user_prompt = f"Context:\n{context}\n\nQuestion: {question}"

        if config.OPENROUTER_API_KEY:
            try:
                client = openai.OpenAI(
                    api_key=config.OPENROUTER_API_KEY,
                    base_url=config.OPENROUTER_BASE_URL,
                )
                response = client.chat.completions.create(
                    model=config.OPENROUTER_MODEL,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt},
                    ],
                    max_tokens=1500,
                    temperature=0.3,
                )
                answer = response.choices[0].message.content
                # content may be None/empty; never hand that back as the answer.
                if answer:
                    return answer
                logger.warning("OpenRouter returned an empty completion; trying next provider")
            except Exception:
                # Best-effort boundary: record the failure, then fall through.
                logger.exception("OpenRouter request failed; trying next provider")

        if config.GEMINI_API_KEY:
            try:
                model = genai.GenerativeModel(config.GEMINI_MODEL)
                # This path sends one combined prompt, so the system
                # instructions are prepended to the user prompt.
                response = model.generate_content(f"{SYSTEM_PROMPT}\n\n{user_prompt}")
                answer = response.text
                if answer:
                    return answer
                logger.warning("Gemini returned an empty response")
            except Exception:
                logger.exception("Gemini request failed")

        return "Unable to generate response. Please check API configuration."

    def _format_sources(self, results: "list[SearchResult]") -> list[dict]:
        """Return the unique (paper, section) pairs behind *results*.

        Order of first appearance is preserved. A missing or empty section
        title is reported as "Content". De-duplication uses the reported
        label, so ``None`` and ``""`` titles of the same paper collapse into
        one entry, and the tuple key cannot collide the way a ":"-joined
        string could.
        """
        seen: set[tuple[str, str]] = set()
        sources: list[dict] = []
        for r in results:
            paper = r.chunk.paper_name
            section = r.chunk.section_title or "Content"
            key = (paper, section)
            if key in seen:
                continue
            seen.add(key)
            sources.append({"paper": paper, "section": section})
        return sources

    def clear_history(self):
        """Reset the conversation history to an empty list."""
        self.conversation_history = []
# Module-level DocumentAgent instance, created when this module is imported.
# Constructing it configures the Gemini client if config.GEMINI_API_KEY is set.
agent = DocumentAgent()