# (Residue copied from the hosting page; kept as comments so the file parses.)
# akhilamarchela0987's picture
# Create agent.py
# 30de503 verified
import re
import requests
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from bs4 import BeautifulSoup
import time
import yaml
import os
# Load prompt template
# Fallback prompt used when prompts.yaml is missing, unreadable, or malformed.
_DEFAULT_PROMPT_TPL = "Question: {question}\nContext: {context}\nAnswer:"


def _load_prompt_template(candidates=None):
    """Return the ``answer_prompt`` template from the first usable prompts.yaml.

    By default the current working directory is searched first (the historical
    behaviour), then the directory containing this module, so the agent also
    picks up its template when launched from elsewhere.

    Args:
        candidates: optional iterable of file paths to try, in order.

    Returns:
        The first ``answer_prompt`` value that is a string, otherwise
        ``_DEFAULT_PROMPT_TPL``.
    """
    if candidates is None:
        candidates = ["prompts.yaml"]
        module_file = globals().get("__file__")
        if module_file:
            module_dir = os.path.dirname(os.path.abspath(module_file))
            candidates.append(os.path.join(module_dir, "prompts.yaml"))
    for path in candidates:
        # Skip absent files up front; only real files are handed to the YAML parser.
        if not os.path.isfile(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as fh:
                data = yaml.safe_load(fh)
            # TypeError covers an empty file (None) or a non-mapping document.
            template = data["answer_prompt"]
        except (OSError, ValueError, KeyError, TypeError, yaml.YAMLError):
            continue
        # A non-string template would only fail later in str.format; reject it here.
        if isinstance(template, str):
            return template
    return _DEFAULT_PROMPT_TPL


_PROMPT_TPL = _load_prompt_template()
class SimpleAgent:
    """
    Lightweight question-answering agent:
    - uses a small seq2seq LLM (Flan-T5 small by default) to generate concise answers
    - uses quick web retrieval (Wikipedia REST summary and DuckDuckGo snippets)
      to build a context string
    - returns stripped answers ready for EXACT-MATCH scoring

    Retrieval is best-effort: network or HTTP failures yield an empty context
    and the model answers from the question alone.
    """

    # Characters trimmed from both ends of a generated answer.
    _EDGE_CHARS = " \n\"'`-:;"
    # Wikimedia asks API clients to send an identifying User-Agent.
    _WIKI_HEADERS = {"User-Agent": "SimpleAgent/1.0 (lightweight QA agent)"}

    def __init__(self, model_name="google/flan-t5-small"):
        """Load the tokenizer and seq2seq model named ``model_name``.

        Raises:
            RuntimeError: if either cannot be loaded; the original error is chained.
        """
        self.model_name = model_name
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        except Exception as e:
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e

    def _wiki_search(self, query, sentences=2):
        """Return the summary extract of the Wikipedia page titled ``query``, or "".

        ``sentences`` is accepted for backward compatibility but is currently
        ignored: the whole summary extract is returned.
        """
        # The REST API expects underscores instead of spaces in titles.
        title = (query or "").strip().replace(" ", "_")
        if not title:
            return ""
        # safe="" so a "/" inside a title is escaped rather than splitting the path.
        url = "https://en.wikipedia.org/api/rest_v1/page/summary/" + requests.utils.quote(
            title, safe=""
        )
        try:
            r = requests.get(url, headers=self._WIKI_HEADERS, timeout=6)
            if r.status_code != 200:
                return ""
            data = r.json()
        except (requests.RequestException, ValueError):
            # Network failure or a non-JSON body: retrieval is best-effort.
            return ""
        if not isinstance(data, dict):
            return ""
        return data.get("extract") or ""

    def _duckduckgo_snippets(self, query, max_chunks=2):
        """
        Return up to ``max_chunks`` DuckDuckGo result snippets joined by spaces.

        Scrapes the DuckDuckGo HTML results page. If the request fails, is
        blocked, or returns an HTTP error status, "" is returned and the agent
        still works model-only.
        """
        if not query or not query.strip():
            return ""
        try:
            r = requests.get(
                "https://duckduckgo.com/html/",
                params={"q": query},  # let requests handle query-string encoding
                headers={"User-Agent": "Mozilla/5.0"},
                timeout=6,
            )
            r.raise_for_status()
        except requests.RequestException:
            return ""
        soup = BeautifulSoup(r.text, "html.parser")
        snippets = [
            node.get_text(separator=" ").strip()
            for node in soup.select(".result__snippet")[:max_chunks]
        ]
        return " ".join(s for s in snippets if s)

    def _clean_answer(self, text):
        """Normalize generated text for exact-match scoring.

        Whitespace runs collapse to a single space. Quotes, backticks, dashes,
        colons and semicolons are trimmed from both ends. A dash directly in
        front of a number (e.g. "-5", "-.5") is kept as a minus sign, so
        negative answers survive. Returns "" for empty or None input.
        """
        if not text:
            return ""
        a = re.sub(r"\s+", " ", text.strip())
        body = a.strip(self._EDGE_CHARS)
        # Prefix removed by the left-hand trim; if it ends in "-" immediately
        # before a number, that dash is a sign and must be restored.
        removed = a[: len(a) - len(a.lstrip(self._EDGE_CHARS))]
        if removed.endswith("-") and re.match(r"\.?\d", body):
            body = "-" + body
        return body

    def _build_context(self, question):
        """Build a retrieval context string (at most 3000 characters) for ``question``.

        The text before the first "?" (trimmed, first 120 characters) is used
        as the search key for Wikipedia and DuckDuckGo. Non-empty results are
        joined with a space. An empty key skips retrieval entirely.
        """
        kw = (question or "").split("?", 1)[0].strip()[:120]
        if not kw:
            return ""
        parts = [p for p in (self._wiki_search(kw), self._duckduckgo_snippets(kw)) if p]
        return " ".join(parts)[:3000]  # limit context length

    def answer(self, question):
        """
        Return a single string: the final answer ONLY (no commentary).
        """
        context = self._build_context(question)
        prompt = _PROMPT_TPL.format(question=question, context=context)
        # NOTE(review): with the tokenizer's default right-side truncation, a prompt
        # longer than 1024 tokens loses its tail (for the default template, the
        # trailing "Answer:" cue); the 3000-char context cap makes this unlikely.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        out = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return self._clean_answer(text)