Spaces:
Sleeping
Sleeping
| """ | |
| Base Agent - Abstract base class cho tất cả agents trong hệ thống. | |
| Cung cấp shared interface và utility methods. | |
| """ | |
| import json | |
| import re | |
| from abc import ABC, abstractmethod | |
| from typing import Optional, Union | |
| from langchain_core.messages import HumanMessage, SystemMessage # type: ignore[import-untyped] | |
| from config.settings import get_llm_for_agent | |
| from utils.logger import get_logger, log_agent_step, log_agent_io | |
| class BaseAgent(ABC): | |
| """ | |
| Abstract base class cho mọi agent. | |
| Mỗi agent kế thừa class này và implement method `run()`. | |
| Kiến trúc pipeline: | |
| Agent 1 (QueryAgent / Groq): | |
| - Sinh queries + crawl dữ liệu thô | |
| - KHÔNG phán xét, KHÔNG phân loại claim | |
| Agent 2 (ExtractorAgent / Gemini): | |
| - Đọc toàn bộ crawled data + tóm tắt từng bài | |
| - KHÔNG lọc, KHÔNG suy luận về claim | |
| Agent 3 (ReasoningAgent / GPT-4o-mini): | |
| - Lọc bài liên quan / reject bài không liên quan | |
| - Tự đánh giá khả năng kiểm chứng của claim | |
| - Suy luận → verdict THẬT / GIẢ / CHƯA XÁC ĐỊNH | |
| """ | |
| # Giới hạn ký tự tối đa cho prompt (mặc định cho Groq free tier) | |
| # ~4 chars ≈ 1 token. Agent 2 (Gemini) sẽ override thành giá trị lớn hơn. | |
| DEFAULT_MAX_PROMPT_CHARS = 12000 | |
| def __init__(self, name: str, agent_type_config: str = "AGENT1", max_prompt_chars: Optional[int] = None): | |
| """ | |
| Khởi tạo agent. | |
| Args: | |
| name: Tên agent (VD: 'QueryAgent', 'ExtractorAgent') | |
| agent_type_config: Prefix cấu hình trong .env (VD: 'AGENT1', 'AGENT2') | |
| max_prompt_chars: Giới hạn ký tự tối đa cho prompt. None = dùng default 12000. | |
| """ | |
| self.name = name | |
| self.logger = get_logger(f"Agent.{name}") | |
| self.llm = get_llm_for_agent(agent_type_config) | |
| self.max_prompt_chars = max_prompt_chars or self.DEFAULT_MAX_PROMPT_CHARS | |
| def run(self, state: dict) -> dict: | |
| """ | |
| Chạy agent với state hiện tại và trả về state đã cập nhật. | |
| Mỗi agent con phải implement method này. | |
| Args: | |
| state: Shared state dictionary chứa toàn bộ dữ liệu pipeline | |
| Returns: | |
| Updated state dictionary | |
| """ | |
| pass | |
| def call_llm(self, system_prompt: str, user_prompt: str) -> str: | |
| """ | |
| Gọi LLM với system prompt và user prompt. | |
| Có cơ chế xoay vòng API Key thông minh: | |
| - Vòng 1: Thử lần lượt TẤT CẢ key trong pool (không ngủ) | |
| - Nếu tất cả key đều bị 429 → Ngủ đúng 65s (chờ Google reset bucket) | |
| - Vòng 2: Thử lại TẤT CẢ key 1 lần nữa (key đã được reset) | |
| - Tối đa 3 vòng ngủ (3 × 65s = ~3 phút) trước khi bỏ cuộc | |
| """ | |
| # Truncate prompt nếu quá dài (tránh Groq TPM limit) | |
| if len(user_prompt) > self.max_prompt_chars: | |
| truncated_len = len(user_prompt) | |
| user_prompt = user_prompt[:self.max_prompt_chars] + \ | |
| f"\n\n[... Đã cắt bớt {truncated_len - self.max_prompt_chars} ký tự để tránh vượt giới hạn API ...]" | |
| self.logger.warning( | |
| f"[{self.name}] Prompt truncated: {truncated_len} → {self.max_prompt_chars} chars" | |
| ) | |
| log_agent_step(self.logger, self.name, "Calling LLM") | |
| # Sanitize text to remove surrogate characters and \x00 which crash the OpenAI parser (400 Bad Request) | |
| user_prompt = user_prompt.replace('\x00', '').encode('utf-8', 'replace').decode('utf-8') | |
| system_prompt = system_prompt.replace('\x00', '').encode('utf-8', 'replace').decode('utf-8') | |
| messages = [ | |
| SystemMessage(content=system_prompt), | |
| HumanMessage(content=user_prompt), | |
| ] | |
| import time as _time | |
| import os | |
| from config.settings import get_next_gemini_key, get_next_groq_key | |
| # Xác định provider hiện tại | |
| model_name_str = str(getattr(self.llm, "model_name", getattr(self.llm, "model", ""))).lower() | |
| base_url_str_check = str(getattr(self.llm, "openai_api_base", getattr(self.llm, "base_url", ""))).lower() | |
| is_gemini = "gemini" in model_name_str or "google" in base_url_str_check | |
| is_groq = "groq" in base_url_str_check | |
| is_hf = "huggingface" in base_url_str_check | |
| is_openrouter = "openrouter" in base_url_str_check | |
| is_openai = "api.openai.com" in base_url_str_check | |
| # Load pool keys | |
| if is_gemini: | |
| pool_keys_str = os.getenv("GEMINI_POOL_KEYS", "") | |
| get_next_key = get_next_gemini_key | |
| provider_name = "Gemini" | |
| elif is_groq: | |
| pool_keys_str = os.getenv("GROQ_POOL_KEYS", "") | |
| get_next_key = get_next_groq_key | |
| provider_name = "Groq" | |
| elif is_hf: | |
| from config.settings import get_next_hf_key | |
| pool_keys_str = os.getenv("HF_POOL_KEYS", "") | |
| get_next_key = get_next_hf_key | |
| provider_name = "HuggingFace" | |
| elif is_openrouter: | |
| from config.settings import get_next_openrouter_key | |
| pool_keys_str = os.getenv("OPENROUTER_POOL_KEYS", "") | |
| get_next_key = get_next_openrouter_key | |
| provider_name = "OpenRouter" | |
| elif is_openai: | |
| from config.settings import get_next_openai_key | |
| pool_keys_str = os.getenv("OPENAI_POOL_KEYS", "") | |
| get_next_key = get_next_openai_key | |
| provider_name = "OpenAI" | |
| else: | |
| pool_keys_str = "" | |
| get_next_key = lambda: None | |
| provider_name = "Unknown" | |
| pool_keys = [k.strip() for k in pool_keys_str.split(",") if k.strip()] | |
| pool_size = len(pool_keys) if pool_keys else 1 | |
| # DEBUG: In ra để biết chính xác đang dùng model/provider nào | |
| base_url_str = str(getattr(self.llm, "openai_api_base", getattr(self.llm, "base_url", "N/A"))) | |
| self.logger.info( | |
| f"[{self.name}] 🔧 DEBUG: model={model_name_str} | provider={provider_name} | " | |
| f"base_url={base_url_str} | pool_size={pool_size}" | |
| ) | |
| # === CHIẾN LƯỢC XOAY VÒNG === | |
| # Tối đa 3 "sleep cycles". Mỗi cycle: | |
| # 1) Thử tất cả key trong pool (mỗi key thử 1 lần) | |
| # 2) Nếu hết key mà vẫn 429 → ngủ 65s rồi thử vòng mới | |
| MAX_SLEEP_CYCLES = 3 | |
| last_error = None | |
| for cycle in range(MAX_SLEEP_CYCLES + 1): | |
| # Nếu đây là cycle > 0 (tức là đã thất bại ở vòng trước), ngủ 65s | |
| if cycle > 0: | |
| self.logger.warning( | |
| f"[{self.name}] 💤 Ngủ đông 65s — chờ {provider_name} reset rate-limit bucket " | |
| f"(Vòng ngủ {cycle}/{MAX_SLEEP_CYCLES})" | |
| ) | |
| _time.sleep(65) | |
| # Thử tất cả key trong pool | |
| for key_idx in range(pool_size): | |
| # Lấy key tiếp theo từ pool (round-robin) | |
| new_key = get_next_key() | |
| if new_key: | |
| from langchain_openai import ChatOpenAI # type: ignore[import-untyped] | |
| self.llm = ChatOpenAI( | |
| model=getattr(self.llm, "model_name", getattr(self.llm, "model", "")), | |
| temperature=getattr(self.llm, "temperature", 0.1), | |
| max_tokens=getattr(self.llm, "max_tokens", 4096), | |
| base_url=getattr(self.llm, "openai_api_base", getattr(self.llm, "base_url", None)), | |
| api_key=new_key, | |
| ) | |
| try: | |
| # BẬT JSON MODE: Đảm bảo model 8B nhỏ không sinh văn bản rác làm gãy JSON Parser | |
| # Sửa lỗi Groq 400: Chỉ bật nếu prompt có từ khoá "json" | |
| needs_json = "json" in system_prompt.lower() or "json" in user_prompt.lower() | |
| if self.name == "QueryAgent" and (is_groq or is_openai) and needs_json: | |
| response = self.llm.bind(response_format={"type": "json_object"}).invoke(messages) | |
| else: | |
| response = self.llm.invoke(messages) | |
| result = str(response.content) | |
| if cycle > 0 or key_idx > 0: | |
| self.logger.info( | |
| f"[{self.name}] ✅ Thành công với {provider_name} key #{key_idx + 1} " | |
| f"(sau {cycle} vòng ngủ)" | |
| ) | |
| log_agent_io(self.logger, self.name, user_prompt, result) | |
| return result | |
| except Exception as e: | |
| last_error = e | |
| error_str = str(e).lower() | |
| is_rate_limit = any(kw in error_str for kw in | |
| ["rate_limit", "429", "413", "exhausted", "resource_exhausted", "402", "depleted", "503", "unavailable", "high demand"]) | |
| if is_rate_limit: | |
| err_snippet = str(e)[:200] | |
| self.logger.warning( | |
| f"[{self.name}] ⚠️ Key #{key_idx + 1}/{pool_size} bị 429 " | |
| f"(cycle {cycle}) — thử key tiếp... | Lỗi: {err_snippet}" | |
| ) | |
| _time.sleep(1) | |
| continue | |
| else: | |
| raise | |
| # Nếu đã ở cycle cuối mà vẫn không thành công | |
| if cycle == MAX_SLEEP_CYCLES: | |
| self.logger.error( | |
| f"[{self.name}] ❌ Đã thử {pool_size} keys × {MAX_SLEEP_CYCLES + 1} vòng " | |
| f"= {pool_size * (MAX_SLEEP_CYCLES + 1)} lần gọi mà vẫn bị Rate Limit. " | |
| f"Bỏ cuộc." | |
| ) | |
| if last_error: | |
| raise last_error | |
| else: | |
| raise RuntimeError("No API keys available or all failed without exception.") | |
| def _clean_json(text: str) -> str: | |
| """Sửa các lỗi JSON phổ biến từ LLM (trailing comma, text thừa).""" | |
| text = re.sub(r",\s*([}\]])", r"\1", text) | |
| return text.strip() | |
| def parse_json_response(self, response: str) -> dict: | |
| """ | |
| Parse JSON từ LLM response. Xử lý trường hợp LLM trả về | |
| markdown code block, text thừa quanh JSON, hoặc trailing commas. | |
| Args: | |
| response: Raw text response từ LLM | |
| Returns: | |
| Parsed dictionary | |
| """ | |
| # Thử parse trực tiếp | |
| try: | |
| return json.loads(response) | |
| except json.JSONDecodeError: | |
| pass | |
| # Thử tìm JSON trong markdown code block | |
| json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", response, re.DOTALL) | |
| if json_match: | |
| cleaned = self._clean_json(json_match.group(1)) | |
| try: | |
| return json.loads(cleaned) | |
| except json.JSONDecodeError: | |
| pass | |
| # Thử tìm JSON object trong text (greedy) | |
| json_match = re.search(r"\{.*\}", response, re.DOTALL) | |
| if json_match: | |
| cleaned = self._clean_json(json_match.group(0)) | |
| try: | |
| return json.loads(cleaned) | |
| except json.JSONDecodeError: | |
| pass | |
| # Thử sửa JSON bị cắt giữa chừng (Gemini hay bị) | |
| json_match = re.search(r"\{.*", response, re.DOTALL) | |
| if json_match: | |
| truncated = self._clean_json(json_match.group(0)) | |
| repaired = self._repair_truncated_json(truncated) | |
| if repaired: | |
| return repaired | |
| # Nếu không parse được, trả về raw text trong dict | |
| self.logger.warning(f"[{self.name}] Could not parse JSON from LLM response") | |
| return {"raw_response": response, "parse_error": True} | |
| def _repair_truncated_json(text: str) -> Optional[dict]: | |
| """Thử sửa JSON bị cắt giữa chừng bằng cách đóng ngoặc thiếu.""" | |
| open_braces = text.count("{") - text.count("}") | |
| open_brackets = text.count("[") - text.count("]") | |
| if open_braces <= 0 and open_brackets <= 0: | |
| return None | |
| # Cắt bỏ value dở dang cuối cùng (sau dấu , hoặc : cuối) | |
| last_comma = max(text.rfind(","), text.rfind(":")) | |
| if last_comma > text.rfind("}") and last_comma > text.rfind("]"): | |
| text = text[:last_comma] | |
| # Đóng ngoặc thiếu | |
| text += "]" * max(0, open_brackets) + "}" * max(0, open_braces) | |
| try: | |
| import json as _json | |
| return _json.loads(text) | |
| except Exception: | |
| return None | |
| def add_log(self, state: dict, message: str) -> None: | |
| """ | |
| Thêm log message vào state để tracking pipeline. | |
| Args: | |
| state: Current state dictionary | |
| message: Log message | |
| """ | |
| if "agent_logs" not in state: | |
| state["agent_logs"] = [] | |
| state["agent_logs"].append(f"[{self.name}] {message}") |