# NOTE: the "Spaces / Runtime error" banner that preceded this file was a web-page
# scrape artifact, not part of the source; it has been replaced with this comment.
import os
import re
import time

import requests
import yaml
from bs4 import BeautifulSoup
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
| # Load prompt template | |
| _PROMPT_TPL = None | |
| try: | |
| with open("prompts.yaml", "r") as f: | |
| _PROMPT_TPL = yaml.safe_load(f)["answer_prompt"] | |
| except Exception: | |
| _PROMPT_TPL = "Question: {question}\nContext: {context}\nAnswer:" | |
class SimpleAgent:
    """Lightweight question-answering agent.

    - uses a small seq2seq LLM (Flan-T5 small by default) to generate concise
      answers
    - builds context via quick web retrieval (Wikipedia REST summary API and
      DuckDuckGo HTML snippets), degrading gracefully to model-only answering
      when the network is unavailable or blocked
    - returns stripped answers ready for EXACT-MATCH scoring
    """

    def __init__(self, model_name="google/flan-t5-small"):
        """Load tokenizer and model.

        Args:
            model_name: Hugging Face model id of a seq2seq model.

        Raises:
            RuntimeError: if the tokenizer or model cannot be loaded.
        """
        self.model_name = model_name
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        except Exception as e:
            # FIX: chain the original exception so the root cause is preserved.
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e

    def _wiki_search(self, query, sentences=2):
        """Fetch the Wikipedia summary for *query*.

        Args:
            query: page title (URL-quoted before the request).
            sentences: maximum number of sentence-like chunks to keep.
                FIX: this parameter was previously declared but ignored;
                it now actually truncates the extract.

        Returns:
            Up to *sentences* sentences of the summary, or "" on any
            network/HTTP failure (best-effort retrieval — never raises).
        """
        try:
            url = "https://en.wikipedia.org/api/rest_v1/page/summary/" + requests.utils.quote(query)
            r = requests.get(url, timeout=6)
            if r.status_code == 200:
                extract = r.json().get("extract", "")
                if extract and sentences:
                    # Naive sentence split on terminal punctuation; good enough
                    # for trimming encyclopedic prose.
                    chunks = re.split(r"(?<=[.!?])\s+", extract)
                    return " ".join(chunks[:sentences])
                return extract
        except Exception:
            pass
        return ""

    def _duckduckgo_snippets(self, query, max_chunks=2):
        """Scrape short result snippets from DuckDuckGo's HTML endpoint.

        Args:
            query: search query string.
            max_chunks: maximum number of snippets to join.

        Returns:
            Space-joined snippet text, or "" on any failure (if DDG blocks the
            request the agent still works with model-only context).
        """
        try:
            headers = {"User-Agent": "Mozilla/5.0"}
            url = f"https://duckduckgo.com/html/?q={requests.utils.quote(query)}"
            r = requests.get(url, headers=headers, timeout=6)
            # FIX: previously the body was parsed regardless of HTTP status,
            # so error/CAPTCHA pages could leak into the context.
            r.raise_for_status()
            soup = BeautifulSoup(r.text, "html.parser")
            snippets = [
                res.get_text(separator=" ")
                for res in soup.select(".result__snippet")[:max_chunks]
            ]
            return " ".join(snippets)
        except Exception:
            return ""

    def _clean_answer(self, text):
        """Normalize a model answer for exact-match scoring.

        Collapses all whitespace runs to single spaces and strips leading and
        trailing quotes, dashes, and punctuation. Returns "" for falsy input.
        """
        if not text:
            return ""
        a = text.strip()
        # collapse newlines/tabs/multiple spaces into single spaces
        a = re.sub(r"\s+", " ", a)
        # remove leading/trailing quotes, backticks, dashes and separators
        a = a.strip(" \n\"'`-:;")
        return a

    def _build_context(self, question):
        """Build a retrieval context for *question*.

        Uses the text before the first '?' (capped at 120 chars) as a crude
        keyword query, tries Wikipedia first, then DuckDuckGo, and joins
        whatever came back. Result is capped at 3000 characters to keep the
        prompt within the model's input budget.
        """
        kw = question.split("?")[0][:120]  # crude keyword extraction
        wiki = self._wiki_search(kw)
        ddg = self._duckduckgo_snippets(kw)
        context_parts = [p for p in [wiki, ddg] if p]
        return " ".join(context_parts)[:3000]

    def answer(self, question):
        """Answer *question* and return the final answer string ONLY.

        Retrieves web context, fills the module-level prompt template, runs
        greedy (deterministic) generation, and returns the cleaned answer —
        no commentary, ready for exact-match comparison.
        """
        context = self._build_context(question)
        prompt = _PROMPT_TPL.format(question=question, context=context)
        # Truncate long prompts to the encoder limit; greedy decoding keeps
        # answers deterministic for scoring.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        out = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return self._clean_answer(text)