| import os |
| import json |
| import re |
| import torch |
| from typing import Dict, Optional |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
# On-disk JSON file used to persist question->answer results between runs.
CACHE_FILE = "gaia_answers_cache.json"
# Default seq2seq checkpoint loaded when the caller does not specify one.
DEFAULT_MODEL = "google/flan-t5-base"
|
|
class EnhancedGAIAAgent:
    """Agent for Hugging Face GAIA with improved question handling.

    Wraps a seq2seq LM (default: flan-t5-base) behind a callable interface.
    Each call classifies the question, generates a raw answer, normalizes it,
    and returns a JSON string of the form '{"final_answer": ...}'.
    Optionally caches the JSON responses on disk (see CACHE_FILE).
    """

    def __init__(self, model_name: str = DEFAULT_MODEL, use_cache: bool = False):
        """Load tokenizer/model and (optionally) the on-disk answer cache.

        Args:
            model_name: Hugging Face model id for a seq2seq LM.
            use_cache: when True, answers are read from / written to CACHE_FILE.
        """
        print(f"Initializing EnhancedGAIAAgent with model: {model_name}")
        self.model_name = model_name
        self.use_cache = use_cache
        self.cache = self._load_cache() if use_cache else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def _load_cache(self) -> Dict[str, str]:
        """Load the cache file; return an empty dict on any read/parse failure."""
        if os.path.exists(CACHE_FILE):
            try:
                with open(CACHE_FILE, 'r', encoding='utf-8') as f:
                    return json.load(f)
            # Narrowed from a bare `except:`: only file/JSON errors mean
            # "start with an empty cache" — anything else should surface.
            except (OSError, json.JSONDecodeError):
                return {}
        return {}

    def _save_cache(self) -> None:
        """Best-effort persist of the cache; caching failures are non-fatal."""
        try:
            with open(CACHE_FILE, 'w', encoding='utf-8') as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
        # Narrowed from a bare `except:`: ignore disk/serialization errors
        # only — the cache is an optimization, not required for correctness.
        except (OSError, TypeError, ValueError):
            pass

    def _classify_question(self, question: str) -> str:
        """Heuristically bucket a question by keyword.

        Returns one of: "calculation", "list", "date_time", "factual".
        """
        question_lower = question.lower()

        if any(word in question_lower for word in ["calculate", "sum", "how many"]):
            return "calculation"
        elif any(word in question_lower for word in ["list", "enumerate"]):
            return "list"
        elif any(word in question_lower for word in ["date", "time", "when"]):
            return "date_time"
        return "factual"

    def _format_answer(self, raw_answer: str, question_type: str) -> str:
        """Normalize a raw model answer for GAIA's exact-match scoring.

        Strips boilerplate prefixes, applies type-specific cleanup
        (first number for calculations, comma-joined items for lists),
        removes surrounding quotes and a trailing period, and collapses
        internal whitespace.
        """
        answer = raw_answer.strip()

        # Drop conversational lead-ins the model tends to emit.
        prefixes = ["Answer:", "The answer is:", "I think", "I believe"]
        for prefix in prefixes:
            if answer.lower().startswith(prefix.lower()):
                answer = answer[len(prefix):].strip()

        if question_type == "calculation":
            # GAIA expects a bare number; keep the first one found.
            numbers = re.findall(r'-?\d+\.?\d*', answer)
            if numbers:
                answer = numbers[0]
        elif question_type == "list":
            # If the model produced space-separated items, re-join with commas.
            if "," not in answer and " " in answer:
                items = [item.strip() for item in answer.split() if item.strip()]
                answer = ", ".join(items)

        answer = answer.strip('"\'')
        # Trailing period is stripped unless it looks like a decimal (e.g. "3.").
        if answer.endswith('.') and not re.match(r'.*\d\.$', answer):
            answer = answer[:-1]
        return re.sub(r'\s+', ' ', answer).strip()

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer a GAIA question; returns a JSON string {"final_answer": ...}.

        Args:
            question: the natural-language question.
            task_id: optional GAIA task id, used as the cache key when given
                (falls back to the question text itself).
        """
        cache_key = task_id if task_id else question
        if self.use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        question_type = self._classify_question(question)

        try:
            # truncation=True guards against inputs longer than the model's
            # max sequence length; no_grad avoids building autograd state
            # during pure inference.
            inputs = self.tokenizer(question, return_tensors="pt", truncation=True)
            with torch.no_grad():
                outputs = self.model.generate(**inputs, max_length=100)
            raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            formatted_answer = self._format_answer(raw_answer, question_type)

            result = {"final_answer": formatted_answer}
            json_response = json.dumps(result)

            # Cache the serialized response so cache hits return the
            # same JSON string as a fresh call.
            if self.use_cache:
                self.cache[cache_key] = json_response
                self._save_cache()

            return json_response

        except Exception as e:
            # Top-level boundary: report the failure in-band so the GAIA
            # harness always receives a well-formed JSON answer.
            return json.dumps({"final_answer": f"AGENT ERROR: {e}"})