Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Optional | |
| from urllib.parse import quote | |
| import pandas as pd | |
| import requests | |
| from datasets import load_dataset | |
| from huggingface_hub import hf_hub_download | |
class TaskFileTool:
    """Downloads and reads task-attachment files for GAIA-style benchmark tasks.

    Attachments are fetched from a scorer API first, with the public GAIA
    dataset on the Hugging Face Hub as a fallback, and cached locally on disk.
    """

    def __init__(self, api_base_url: str, cache_dir: str = "task_files", timeout: int = 30):
        # Normalize the base URL so endpoint paths can be appended with "/".
        self.api_base_url = api_base_url.rstrip("/")
        # Per-request timeout (seconds) for scorer API calls.
        self.timeout = timeout
        # Local on-disk cache for downloaded attachments; created eagerly.
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Lazily-built task_id -> dataset file_path index; None until first use.
        self._task_to_file_path: dict[str, str] | None = None
| def _ensure_gaia_index(self) -> None: | |
| if self._task_to_file_path is not None: | |
| return | |
| print("[TaskFileTool] Loading GAIA validation split for file-path index...") | |
| self._task_to_file_path = {} | |
| try: | |
| ds = load_dataset( | |
| "gaia-benchmark/GAIA", | |
| "2023_level1", | |
| split="validation", | |
| ) | |
| for ex in ds: | |
| ex_dict = dict(ex) | |
| task_id = str(ex_dict.get("task_id", "")).strip() | |
| file_path = str(ex_dict.get("file_path", "")).strip() | |
| if task_id and file_path: | |
| self._task_to_file_path[task_id] = file_path | |
| print(f"[TaskFileTool] Indexed {len(self._task_to_file_path)} task_id -> file_path rows") | |
| except Exception as e: | |
| print(f"[TaskFileTool] Failed to build GAIA index: {e}") | |
| self._task_to_file_path = {} | |
| def get_task_context(self, task_id: str) -> str: | |
| file_path = self.download_task_file(task_id=task_id) | |
| if file_path is None: | |
| return "" | |
| return self.read_file_as_text(file_path) | |
| def download_task_file( | |
| self, | |
| task_id: Optional[str] = None, | |
| file_name: Optional[str] = None, | |
| ) -> Optional[Path]: | |
| """ | |
| Strategy: | |
| 1. Try scorer endpoints | |
| 2. Fallback to GAIA dataset file_path via hf_hub_download | |
| """ | |
| # 0) If file_name already exists in cache, use it immediately | |
| if file_name: | |
| local_candidate = self.cache_dir / Path(str(file_name)).name | |
| if local_candidate.exists(): | |
| print(f"[download_task_file] using cached file -> {local_candidate}") | |
| return local_candidate | |
| # 1) Try scorer API first | |
| candidates: list[str] = [] | |
| if file_name: | |
| candidates.append(str(file_name)) | |
| candidates.append(Path(str(file_name)).name) | |
| if task_id: | |
| candidates.append(str(task_id)) | |
| seen: set[str] = set() | |
| for ident in candidates: | |
| ident = str(ident).strip() | |
| if not ident or ident in seen: | |
| continue | |
| seen.add(ident) | |
| urls = [ | |
| f"{self.api_base_url}/files/{quote(ident)}", | |
| f"{self.api_base_url}/file/{quote(ident)}", | |
| ] | |
| for url in urls: | |
| print(f"[download_task_file] trying {url}") | |
| try: | |
| response = requests.get(url, timeout=self.timeout) | |
| print( | |
| f"[download_task_file] status={response.status_code} " | |
| f"url={url} " | |
| f"content_type={response.headers.get('content-type')} " | |
| f"content_disposition={response.headers.get('content-disposition')}" | |
| ) | |
| except requests.RequestException as e: | |
| print(f"[download_task_file] request error for {url}: {e}") | |
| continue | |
| if response.status_code != 200: | |
| preview = response.text[:200] if hasattr(response, "text") else "" | |
| print(f"[download_task_file] failed {url} body={preview}") | |
| continue | |
| filename = self._infer_filename(response=response, fallback_name=ident) | |
| file_path = self.cache_dir / filename | |
| try: | |
| with open(file_path, "wb") as f: | |
| f.write(response.content) | |
| print(f"[download_task_file] saved -> {file_path}") | |
| return file_path | |
| except OSError as e: | |
| print(f"[download_task_file] save error: {e}") | |
| return None | |
| # 2) Fallback: GAIA dataset repo by task_id -> file_path | |
| if task_id: | |
| self._ensure_gaia_index() | |
| repo_rel_path = (self._task_to_file_path or {}).get(str(task_id), "") | |
| if repo_rel_path: | |
| print(f"[download_task_file] fallback via GAIA dataset file_path={repo_rel_path}") | |
| try: | |
| cached_path = hf_hub_download( | |
| repo_id="gaia-benchmark/GAIA", | |
| repo_type="dataset", | |
| filename=repo_rel_path, | |
| ) | |
| src = Path(cached_path) | |
| dst = self.cache_dir / src.name | |
| if not dst.exists(): | |
| dst.write_bytes(src.read_bytes()) | |
| print(f"[download_task_file] dataset fallback saved -> {dst}") | |
| return dst | |
| except Exception as e: | |
| print(f"[download_task_file] dataset fallback ERROR: {e}") | |
| return None | |
| def read_file_as_text(self, file_path: Path) -> str: | |
| suffix = file_path.suffix.lower() | |
| try: | |
| if suffix in {".txt", ".md", ".html", ".xml", ".csv", ".json", ".py"}: | |
| return self._read_supported_text_file(file_path, suffix) | |
| if suffix in {".xlsx", ".xls"}: | |
| return self._read_excel_preview(file_path) | |
| if suffix == "": | |
| return self._read_extensionless_file(file_path) | |
| return "" | |
| except Exception: | |
| return "" | |
| def _read_supported_text_file(self, file_path: Path, suffix: str) -> str: | |
| if suffix in {".txt", ".md", ".html", ".xml", ".py"}: | |
| return file_path.read_text(encoding="utf-8", errors="ignore") | |
| if suffix == ".json": | |
| raw = file_path.read_text(encoding="utf-8", errors="ignore") | |
| try: | |
| parsed = json.loads(raw) | |
| return json.dumps(parsed, indent=2, ensure_ascii=False) | |
| except json.JSONDecodeError: | |
| return raw | |
| if suffix == ".csv": | |
| try: | |
| df = pd.read_csv(file_path) | |
| return df.to_csv(index=False) | |
| except Exception: | |
| return file_path.read_text(encoding="utf-8", errors="ignore") | |
| return "" | |
| def _read_excel_preview(self, file_path: Path) -> str: | |
| try: | |
| xls = pd.ExcelFile(file_path) | |
| chunks: list[str] = [] | |
| for sheet_name in xls.sheet_names[:5]: | |
| try: | |
| df = pd.read_excel(file_path, sheet_name=sheet_name) | |
| chunks.append(f"Sheet: {sheet_name}") | |
| chunks.append(df.head(20).to_csv(index=False)) | |
| except Exception: | |
| continue | |
| return "\n\n".join(chunks).strip() | |
| except Exception: | |
| return "" | |
| def _read_extensionless_file(self, file_path: Path) -> str: | |
| try: | |
| raw = file_path.read_text(encoding="utf-8", errors="ignore") | |
| if raw.strip(): | |
| return raw | |
| except Exception: | |
| pass | |
| return "" | |
| def _infer_filename(self, response: requests.Response, fallback_name: str) -> str: | |
| content_disposition = response.headers.get("content-disposition", "") | |
| filename = self._extract_filename_from_content_disposition(content_disposition) | |
| if filename: | |
| return self._safe_filename(filename) | |
| content_type = response.headers.get("content-type", "").lower() | |
| extension = self._extension_from_content_type(content_type) | |
| fallback_base = self._safe_filename(fallback_name) | |
| if Path(fallback_base).suffix: | |
| return fallback_base | |
| if extension: | |
| return f"{fallback_base}{extension}" | |
| return fallback_base | |
| def _extract_filename_from_content_disposition(content_disposition: str) -> Optional[str]: | |
| if "filename=" not in content_disposition: | |
| return None | |
| try: | |
| filename = content_disposition.split("filename=")[-1].strip().strip('"') | |
| return filename or None | |
| except Exception: | |
| return None | |
| def _extension_from_content_type(content_type: str) -> str: | |
| mapping = { | |
| "text/plain": ".txt", | |
| "text/csv": ".csv", | |
| "application/csv": ".csv", | |
| "application/json": ".json", | |
| "text/markdown": ".md", | |
| "text/html": ".html", | |
| "application/xml": ".xml", | |
| "text/xml": ".xml", | |
| "application/pdf": ".pdf", | |
| "image/png": ".png", | |
| "image/jpeg": ".jpg", | |
| "image/jpg": ".jpg", | |
| "image/webp": ".webp", | |
| "audio/mpeg": ".mp3", | |
| "audio/mp3": ".mp3", | |
| "audio/wav": ".wav", | |
| "audio/x-wav": ".wav", | |
| "audio/mp4": ".m4a", | |
| "audio/x-m4a": ".m4a", | |
| "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", | |
| "application/vnd.ms-excel": ".xls", | |
| "text/x-python": ".py", | |
| "text/python": ".py", | |
| } | |
| for key, ext in mapping.items(): | |
| if key in content_type: | |
| return ext | |
| return "" | |
| def _safe_filename(filename: str) -> str: | |
| return os.path.basename(filename) |