Spaces:
Sleeping
Sleeping
"""GAIA attachment handling + question -> task_id index.

Because of CodeAgent's signature constraint (``__call__`` receives only the
question), the task_id cannot be injected directly. This module works around
that with a module-level mutable container plus a prefetched index.

Flow:
  1) During BasicAgent.__init__, prefetch_question_index() calls /questions
     once, builds a {question text: task_id} dict, and registers it with
     set_question_index().
  2) On entry to BasicAgent.__call__, set_current_task(question) stores the
     current task's task_id and question text in _CURRENT_TASK.
  3) When the agent calls get_attached_file() with no arguments, the file is
     downloaded from the scoring server using the task_id in _CURRENT_TASK
     and processed according to its type:
       - Text/CSV/JSON/code: UTF-8 decoding
       - Excel (.xlsx): one CSV per sheet
       - PDF: per-page text extraction (pypdf)
       - Image: analysed by a VLM (Qwen2.5-VL-7B) in the context of the
         current question
       - Audio: Whisper (large-v3) transcription
"""
| import io | |
| import re | |
| import requests | |
| from smolagents import tool | |
# The scoring-server URL is defined here as well (per the original note, the
# same value as in app.py) so that the tools module still makes sense when it
# is used on its own.
_DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Mutable container refreshed on entry to BasicAgent.__call__.
# The question is used as context (prompt) for the image VLM call.
_CURRENT_TASK = {"id": None, "question": None}
# Maps question.strip() -> task_id.
_QUESTION_INDEX: dict = {}
def prefetch_question_index() -> dict:
    """Call the scoring server's /questions endpoint once and build an index.

    Returns:
        A dict mapping each stripped question text to its task_id. An empty
        dict is returned when the request or JSON decoding fails, or when the
        payload is not a list, so the agent can still attempt questions that
        have no attachment.
    """
    try:
        response = requests.get(f"{_DEFAULT_API_URL}/questions", timeout=15)
        response.raise_for_status()
        payload = response.json()
    except Exception as e:
        # Deliberate best effort: a prefetch failure must never propagate.
        print(f"Warning: could not prefetch question index: {e}")
        return {}

    if not isinstance(payload, list):
        print(
            "Warning: could not prefetch question index: "
            f"unexpected payload type {type(payload).__name__}"
        )
        return {}

    index = {}
    for item in payload:
        # Skip malformed entries individually instead of letting one bad item
        # raise and throw away every valid entry collected so far.
        if not isinstance(item, dict):
            continue
        question = item.get("question")
        task_id = item.get("task_id")
        if not isinstance(question, str) or not task_id:
            continue
        text = question.strip()
        if not text:
            continue
        if text in index and index[text] != task_id:
            print(
                "Warning: duplicate question text in prefetch index — "
                f"task_id {index[text]!r} will be overwritten by {task_id!r}"
            )
        index[text] = task_id
    return index
def set_question_index(idx: dict) -> None:
    """Install the prefetched question index as module-level state.

    Per the original note, this is the setter BasicAgent.__init__ uses to
    publish the prefetch result.

    Args:
        idx: Mapping of stripped question text to task_id. The dict is stored
            by reference, not copied. ``None`` is treated as an empty index so
            later lookups with ``.get`` cannot fail.
    """
    global _QUESTION_INDEX
    _QUESTION_INDEX = idx if idx is not None else {}
def set_current_task(question: str):
    """Record the current task's task_id and question text in _CURRENT_TASK.

    Intended to be called on entry to BasicAgent.__call__ (per the original
    note). The question text is later used as prompt context when an image
    attachment is sent to the VLM.

    Args:
        question: The question text. It is stripped before the index lookup
            but stored unmodified. A ``None`` or non-string value is stored
            as-is and simply resolves to no task_id instead of raising.

    Returns:
        The matching task_id, or ``None`` when the question is not in the
        index.
    """
    if isinstance(question, str):
        task_id = _QUESTION_INDEX.get(question.strip())
    else:
        # Nothing to look up; avoid AttributeError from calling .strip() on None.
        task_id = None
    _CURRENT_TASK["id"] = task_id
    _CURRENT_TASK["question"] = question
    return task_id
# --- File-type detection helpers ---
| def _extract_filename(headers, url: str) -> str: | |
| """Content-Disposition ν€λμμ filenameμ λ½κ±°λ, URL λλΆλΆμΌλ‘ ν΄λ°±. | |
| μ±μ μλ²κ° Content-Typeμ octet-streamμΌλ‘ μ€ λ νμ₯μλ‘ λ³΄κ°νκΈ° μν¨.""" | |
| cd = headers.get("Content-Disposition", "") | |
| # filename* (RFC 5987) μ filename= μμͺ½ λ€ μ²λ¦¬. | |
| m = re.search(r'filename\*?=(?:UTF-8\'\')?"?([^";\r\n]+)"?', cd) | |
| if m: | |
| return m.group(1).strip().strip('"') | |
| return url.rsplit("/", 1)[-1] | |
| def _is_excel(content_type: str, ext: str) -> bool: | |
| if ext in ("xlsx", "xls"): | |
| return True | |
| ct = content_type.lower() | |
| return "spreadsheet" in ct or ct.endswith("xlsx") or ct.endswith("xls") or "excel" in ct | |
| def _is_pdf(content_type: str, ext: str) -> bool: | |
| return ext == "pdf" or "pdf" in content_type.lower() | |
| def _is_image(content_type: str, ext: str) -> bool: | |
| return ext in ("png", "jpg", "jpeg", "webp", "gif", "bmp") \ | |
| or content_type.lower().startswith("image/") | |
| def _is_audio(content_type: str, ext: str) -> bool: | |
| return ext in ("mp3", "wav", "m4a", "ogg", "flac") \ | |
| or content_type.lower().startswith("audio/") | |
# --- Per-type handlers ---
| def _handle_excel(content: bytes, content_type: str) -> str: | |
| """xlsx β μνΈλ³ CSVλ‘ μ§λ ¬ν. GAIAμ λ§€μΆ/νλ§€ λ°μ΄ν° λ¬Έμ κ° μμ£Ό λμ¨λ€.""" | |
| try: | |
| import pandas as _pd | |
| bio = io.BytesIO(content) | |
| sheets = _pd.read_excel(bio, sheet_name=None) | |
| parts = [] | |
| for name, df in sheets.items(): | |
| parts.append(f"--- Sheet: {name} ---\n{df.to_csv(index=False)}") | |
| combined = "\n\n".join(parts) | |
| if len(combined) > 12000: | |
| combined = combined[:12000] + "\n...[truncated]" | |
| return f"[Content-Type: {content_type}]\n{combined}" | |
| except Exception as e: | |
| return f"Excel parse error: {e}" | |
| def _handle_pdf(content: bytes, content_type: str) -> str: | |
| """pypdfλ‘ PDF λ³Έλ¬Έ ν μ€νΈ μΆμΆ. νμ΄μ§λ³λ‘ ꡬλΆν΄μ λ°ν. | |
| μ€μΊ PDF(μ΄λ―Έμ§λ‘ λ)λ ν μ€νΈκ° λΉκ±°λ κΉ¨μ§ μ μλλ°, κ·Έ κ²½μ°λ | |
| LLMμ΄ μν€/μΉκ²μμΌλ‘ ν΄λ°±νλλ‘ μμ€ν ν둬ννΈκ° μ λνλ€.""" | |
| try: | |
| from pypdf import PdfReader | |
| bio = io.BytesIO(content) | |
| reader = PdfReader(bio) | |
| parts = [] | |
| for i, page in enumerate(reader.pages): | |
| try: | |
| txt = page.extract_text() or "" | |
| except Exception as pe: | |
| txt = f"(extraction failed: {pe})" | |
| parts.append(f"--- Page {i+1} ---\n{txt}") | |
| combined = "\n\n".join(parts) | |
| if len(combined) > 12000: | |
| combined = combined[:12000] + "\n...[truncated]" | |
| return f"[PDF, {len(reader.pages)} pages, Content-Type: {content_type}]\n{combined}" | |
| except Exception as e: | |
| return f"PDF parse error: {e}" | |
| def _handle_image(content: bytes, content_type: str) -> str: | |
| """VLM(Qwen2.5-VL-7B)μΌλ‘ νμ¬ μ§λ¬Έ 컨ν μ€νΈμ λ§μΆ° μ΄λ―Έμ§λ₯Ό λΆμνλ€. | |
| HF Inference APIμ OpenAI νΈν chat_completionμΌλ‘ base64 data URLμ μ μ‘νλ€. | |
| μ§λ¬Έ 컨ν μ€νΈκ° μμΌλ©΄ κ·Έκ±Έ κ·Έλλ‘ promptμ λ°μ μ λ΅μ μ§μ λμμ΄ λλ | |
| λΆλΆλ§ λ½μλ΄λλ‘ μ λ(generic μΊ‘μ μ λν μΌμ λμΉ¨). νΈμΆ μ€ν¨ μ μλ¬ | |
| λ¬Έμμ΄μ λ°νν΄ μμ΄μ νΈκ° λ€λ₯Έ μ λ΅μΌλ‘ ν΄λ°±ν μ μκ² νλ€. | |
| HF_TOKEN νκ²½λ³μκ° νμνλ€. Space λ°°ν¬ μμλ Space secretsμ λ±λ‘ν΄μΌ ν¨. | |
| """ | |
| try: | |
| import base64 | |
| from huggingface_hub import InferenceClient | |
| question = (_CURRENT_TASK.get("question") or "").strip() | |
| # λ°μ΄ν° URL ꡬμ±. content_typeμ΄ image/* κ° μλ μλ μμ΄ μμ νκ² ν΄λ°±. | |
| mime = content_type.split(";")[0].strip() | |
| if not mime.startswith("image/"): | |
| mime = "image/png" | |
| b64 = base64.b64encode(content).decode("utf-8") | |
| data_url = f"data:{mime};base64,{b64}" | |
| if question: | |
| prompt = ( | |
| "Analyze the attached image and answer the following question. " | |
| "Read any text, numbers, or labels visible in the image. " | |
| "If it is a chart or table, extract the relevant data values precisely.\n\n" | |
| f"Question: {question}" | |
| ) | |
| else: | |
| prompt = ( | |
| "Describe the attached image in detail, including any visible text, " | |
| "numbers, or labels." | |
| ) | |
| client = InferenceClient(provider="auto") # HF_TOKEN νκ²½λ³μ μ¬μ© | |
| resp = client.chat_completion( | |
| model="Qwen/Qwen2.5-VL-7B-Instruct", | |
| messages=[ | |
| { | |
| "role": "user", | |
| "content": [ | |
| {"type": "text", "text": prompt}, | |
| {"type": "image_url", "image_url": {"url": data_url}}, | |
| ], | |
| } | |
| ], | |
| max_tokens=1024, | |
| ) | |
| analysis = resp.choices[0].message.content | |
| return ( | |
| f"[Image analysis (Content-Type: {content_type}, {len(content)} bytes)]\n" | |
| f"{analysis}" | |
| ) | |
| except Exception as e: | |
| return ( | |
| f"Image attached (Content-Type: {content_type}, {len(content)} bytes). " | |
| f"VLM analysis failed: {e}" | |
| ) | |
| def _handle_audio(content: bytes, content_type: str) -> str: | |
| """Whisper(large-v3)λ‘ μ€λμ€ μ μ¬. GAIA μ€λμ€λ λ³΄ν΅ μ§§μ λ°νλΌ ν λ² νΈμΆλ‘ μΆ©λΆ. | |
| HF_TOKEN νκ²½λ³μκ° νμνλ€. Space λ°°ν¬ μμλ Space secretsμ λ±λ‘ν΄μΌ ν¨. | |
| """ | |
| try: | |
| from huggingface_hub import InferenceClient | |
| client = InferenceClient(provider="auto") | |
| result = client.automatic_speech_recognition( | |
| audio=content, | |
| model="openai/whisper-large-v3", | |
| ) | |
| # huggingface_hub λ²μ μ λ°λΌ dict λλ dataclass-like κ°μ²΄λ‘ λ°νλλ―λ‘ | |
| # μμͺ½ λͺ¨λ μ²λ¦¬νλ€. | |
| if hasattr(result, "text"): | |
| transcription = result.text | |
| elif isinstance(result, dict): | |
| transcription = result.get("text", str(result)) | |
| else: | |
| transcription = str(result) | |
| return ( | |
| f"[Audio transcription (Content-Type: {content_type}, {len(content)} bytes)]\n" | |
| f"{transcription}" | |
| ) | |
| except Exception as e: | |
| return ( | |
| f"Audio attached (Content-Type: {content_type}, {len(content)} bytes). " | |
| f"Transcription failed: {e}" | |
| ) | |
def get_attached_file() -> str:
    """Download the file attached to the CURRENT GAIA task and return its content.
    Takes no arguments — the current task_id is auto-resolved from the question.
    Use this whenever the question references a file, spreadsheet, image, audio, PDF, code listing,
    CSV, or any external resource. Returns:
    - Text/CSV/JSON/code: the decoded text (truncated to ~12k chars).
    - Excel (.xlsx): each sheet rendered as CSV (truncated).
    - PDF: extracted text per page (truncated).
    - Image (PNG/JPEG/WEBP/GIF/BMP): a vision-language model analysis focused on the current question.
    - Audio (MP3/WAV/M4A/OGG/FLAC): a Whisper transcription.
    - Other binary: a metadata description (size + content-type).
    """
    # Because of the signature constraint this function takes no task_id
    # argument; it reads the id from the module-level _CURRENT_TASK, which
    # set_current_task() fills in on entry to BasicAgent.__call__.
    task_id = _CURRENT_TASK.get("id")
    if not task_id:
        return "No task context available — likely no file attached for this question."
    try:
        url = f"{_DEFAULT_API_URL}/files/{task_id}"
        r = requests.get(url, timeout=30)
        if r.status_code == 404:
            return "No file attached to this task."
        r.raise_for_status()
        content_type = r.headers.get("Content-Type", "")
        filename = _extract_filename(r.headers, url)
        ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

        # 1) Handle clearly binary/structured types first. Some PDF/SVG files
        #    decode as UTF-8, but returning them as raw text degrades quality
        #    badly. Order matters: the first matching detector wins.
        dispatch = (
            (_is_excel, _handle_excel),
            (_is_pdf, _handle_pdf),
            (_is_image, _handle_image),
            (_is_audio, _handle_audio),
        )
        for matches, handler in dispatch:
            if matches(content_type, ext):
                return handler(r.content, content_type)

        # 2) Text-like content: return it decoded. "utf-8-sig" is plain UTF-8
        #    that also drops a leading byte-order mark, so a BOM does not leak
        #    into the first CSV/JSON token.
        try:
            text = r.content.decode("utf-8-sig")
        except UnicodeDecodeError:
            text = None
        if text is not None:
            if len(text) > 12000:
                text = text[:12000] + "\n...[truncated]"
            return f"[Content-Type: {content_type}]\n{text}"

        # 3) Unknown binary: return metadata only.
        return (
            f"Binary file (Content-Type: {content_type}, "
            f"size: {len(r.content)} bytes). Cannot display as text. "
            f"URL: {url}"
        )
    except Exception as e:
        # Tool boundary: always hand the agent a string instead of raising.
        return f"get_attached_file error: {e}"