# NOTE: removed "Spaces: / Sleeping / Sleeping" — Hugging Face Spaces page residue
# accidentally pasted above the source; it was not part of the program.
| # --- agent.py --- | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from duckduckgo_search import DDGS | |
| import torch | |
# Instruction prefix prepended to every model prompt. It asks the model to
# answer with only the raw answer string — no preamble, labels, or formatting —
# so the agent's output can be compared/submitted verbatim.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
Then return only the answer without any explanation or formatting.
Do not say 'Final answer' or anything else. Just output the raw answer string.
"""
def web_search(query: str, max_results: int = 3) -> list[str]:
    """Run a DuckDuckGo text search and return formatted result snippets.

    Each snippet has the form ``"title: body (URL: href)"``. This is
    best-effort: on any failure a single ``"[Web search error: ...]"``
    entry is appended (after any snippets collected so far), so callers
    always receive a list and never see an exception.
    """
    snippets: list[str] = []
    try:
        with DDGS() as ddgs:
            for hit in ddgs.text(query, max_results=max_results):
                snippets.append(
                    f"{hit['title']}: {hit['body']} (URL: {hit['href']})"
                )
    except Exception as exc:  # deliberate: surface the error as a result line
        snippets.append(f"[Web search error: {exc}]")
    return snippets
class GaiaAgent:
    """Minimal question-answering agent backed by a seq2seq model.

    When a question looks like it needs external facts (keyword heuristic),
    DuckDuckGo snippets are injected into the prompt as context. Calling the
    agent returns ``(answer, trace)`` where *trace* records whether and how
    search was used.
    """

    # Trigger words for the web-search heuristic. Matching is by substring
    # on the lowercased question, so e.g. "code" also fires on "encoded" —
    # kept as-is for backward compatibility with the original behavior.
    SEARCH_KEYWORDS = (
        "wikipedia", "who", "when", "where", "youtube", "mp3", "video",
        "article", "name", "code", "city", "award", "nasa",
    )

    def __init__(self, model_id: str = "google/flan-t5-base"):
        """Load tokenizer and model, and place the model on GPU if available."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _needs_search(self, question: str) -> bool:
        """Heuristic: does the question require external facts?"""
        lowered = question.lower()
        return any(keyword in lowered for keyword in self.SEARCH_KEYWORDS)

    def _build_prompt(self, question: str) -> tuple[str, str]:
        """Return ``(prompt, trace)``, adding search context when the heuristic fires."""
        if self._needs_search(question):
            context = "\n".join(web_search(question))
            prompt = f"{SYSTEM_PROMPT}\n\nSearch context:\n{context}\n\nQuestion: {question}"
            trace = f"Search used:\n{context}"
        else:
            prompt = f"{SYSTEM_PROMPT}\n\n{question}"
            trace = "Search not used."
        return prompt, trace

    def __call__(self, question: str) -> tuple[str, str]:
        """Answer *question*.

        Returns ``(answer, trace)``; never raises — any failure is reported
        in-band as ``("ERROR", "Agent failed: ...")``.
        """
        try:
            prompt, trace = self._build_prompt(question)
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True
            ).to(self.device)
            # Defensive no_grad: generate() is typically no-grad already in
            # recent transformers, but this guarantees no autograd graph is built.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=False,  # greedy decoding for reproducible answers
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return output_text.strip(), trace
        except Exception as e:
            # Top-level boundary: report the failure instead of raising.
            return "ERROR", f"Agent failed: {e}"