from __future__ import annotations import json import os from pathlib import Path from typing import Optional from urllib.parse import quote import pandas as pd import requests from datasets import load_dataset from huggingface_hub import hf_hub_download class TaskFileTool: def __init__(self, api_base_url: str, cache_dir: str = "task_files", timeout: int = 30): self.api_base_url = api_base_url.rstrip("/") self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self.timeout = timeout self._task_to_file_path: dict[str, str] | None = None def _ensure_gaia_index(self) -> None: if self._task_to_file_path is not None: return print("[TaskFileTool] Loading GAIA validation split for file-path index...") self._task_to_file_path = {} try: ds = load_dataset( "gaia-benchmark/GAIA", "2023_level1", split="validation", ) for ex in ds: ex_dict = dict(ex) task_id = str(ex_dict.get("task_id", "")).strip() file_path = str(ex_dict.get("file_path", "")).strip() if task_id and file_path: self._task_to_file_path[task_id] = file_path print(f"[TaskFileTool] Indexed {len(self._task_to_file_path)} task_id -> file_path rows") except Exception as e: print(f"[TaskFileTool] Failed to build GAIA index: {e}") self._task_to_file_path = {} def get_task_context(self, task_id: str) -> str: file_path = self.download_task_file(task_id=task_id) if file_path is None: return "" return self.read_file_as_text(file_path) def download_task_file( self, task_id: Optional[str] = None, file_name: Optional[str] = None, ) -> Optional[Path]: """ Strategy: 1. Try scorer endpoints 2. Fallback to GAIA dataset file_path via hf_hub_download """ # 0) If file_name already exists in cache, use it immediately if file_name: local_candidate = self.cache_dir / Path(str(file_name)).name if local_candidate.exists(): print(f"[download_task_file] using cached file -> {local_candidate}") return local_candidate # 1) Try scorer API first candidates: list[str] = [] if file_name: candidates.append(str(file_name)) candidates.append(Path(str(file_name)).name) if task_id: candidates.append(str(task_id)) seen: set[str] = set() for ident in candidates: ident = str(ident).strip() if not ident or ident in seen: continue seen.add(ident) urls = [ f"{self.api_base_url}/files/{quote(ident)}", f"{self.api_base_url}/file/{quote(ident)}", ] for url in urls: print(f"[download_task_file] trying {url}") try: response = requests.get(url, timeout=self.timeout) print( f"[download_task_file] status={response.status_code} " f"url={url} " f"content_type={response.headers.get('content-type')} " f"content_disposition={response.headers.get('content-disposition')}" ) except requests.RequestException as e: print(f"[download_task_file] request error for {url}: {e}") continue if response.status_code != 200: preview = response.text[:200] if hasattr(response, "text") else "" print(f"[download_task_file] failed {url} body={preview}") continue filename = self._infer_filename(response=response, fallback_name=ident) file_path = self.cache_dir / filename try: with open(file_path, "wb") as f: f.write(response.content) print(f"[download_task_file] saved -> {file_path}") return file_path except OSError as e: print(f"[download_task_file] save error: {e}") return None # 2) Fallback: GAIA dataset repo by task_id -> file_path if task_id: self._ensure_gaia_index() repo_rel_path = (self._task_to_file_path or {}).get(str(task_id), "") if repo_rel_path: print(f"[download_task_file] fallback via GAIA dataset file_path={repo_rel_path}") try: cached_path = hf_hub_download( repo_id="gaia-benchmark/GAIA", repo_type="dataset", filename=repo_rel_path, ) src = Path(cached_path) dst = self.cache_dir / src.name if not dst.exists(): dst.write_bytes(src.read_bytes()) print(f"[download_task_file] dataset fallback saved -> {dst}") return dst except Exception as e: print(f"[download_task_file] dataset fallback ERROR: {e}") return None def read_file_as_text(self, file_path: Path) -> str: suffix = file_path.suffix.lower() try: if suffix in {".txt", ".md", ".html", ".xml", ".csv", ".json", ".py"}: return self._read_supported_text_file(file_path, suffix) if suffix in {".xlsx", ".xls"}: return self._read_excel_preview(file_path) if suffix == "": return self._read_extensionless_file(file_path) return "" except Exception: return "" def _read_supported_text_file(self, file_path: Path, suffix: str) -> str: if suffix in {".txt", ".md", ".html", ".xml", ".py"}: return file_path.read_text(encoding="utf-8", errors="ignore") if suffix == ".json": raw = file_path.read_text(encoding="utf-8", errors="ignore") try: parsed = json.loads(raw) return json.dumps(parsed, indent=2, ensure_ascii=False) except json.JSONDecodeError: return raw if suffix == ".csv": try: df = pd.read_csv(file_path) return df.to_csv(index=False) except Exception: return file_path.read_text(encoding="utf-8", errors="ignore") return "" def _read_excel_preview(self, file_path: Path) -> str: try: xls = pd.ExcelFile(file_path) chunks: list[str] = [] for sheet_name in xls.sheet_names[:5]: try: df = pd.read_excel(file_path, sheet_name=sheet_name) chunks.append(f"Sheet: {sheet_name}") chunks.append(df.head(20).to_csv(index=False)) except Exception: continue return "\n\n".join(chunks).strip() except Exception: return "" def _read_extensionless_file(self, file_path: Path) -> str: try: raw = file_path.read_text(encoding="utf-8", errors="ignore") if raw.strip(): return raw except Exception: pass return "" def _infer_filename(self, response: requests.Response, fallback_name: str) -> str: content_disposition = response.headers.get("content-disposition", "") filename = self._extract_filename_from_content_disposition(content_disposition) if filename: return self._safe_filename(filename) content_type = response.headers.get("content-type", "").lower() extension = self._extension_from_content_type(content_type) fallback_base = self._safe_filename(fallback_name) if Path(fallback_base).suffix: return fallback_base if extension: return f"{fallback_base}{extension}" return fallback_base @staticmethod def _extract_filename_from_content_disposition(content_disposition: str) -> Optional[str]: if "filename=" not in content_disposition: return None try: filename = content_disposition.split("filename=")[-1].strip().strip('"') return filename or None except Exception: return None @staticmethod def _extension_from_content_type(content_type: str) -> str: mapping = { "text/plain": ".txt", "text/csv": ".csv", "application/csv": ".csv", "application/json": ".json", "text/markdown": ".md", "text/html": ".html", "application/xml": ".xml", "text/xml": ".xml", "application/pdf": ".pdf", "image/png": ".png", "image/jpeg": ".jpg", "image/jpg": ".jpg", "image/webp": ".webp", "audio/mpeg": ".mp3", "audio/mp3": ".mp3", "audio/wav": ".wav", "audio/x-wav": ".wav", "audio/mp4": ".m4a", "audio/x-m4a": ".m4a", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", "application/vnd.ms-excel": ".xls", "text/x-python": ".py", "text/python": ".py", } for key, ext in mapping.items(): if key in content_type: return ext return "" @staticmethod def _safe_filename(filename: str) -> str: return os.path.basename(filename)