# Fix prompts and utils (commit 900ed7a, abhi1294)
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Optional
from urllib.parse import quote
import pandas as pd
import requests
from datasets import load_dataset
from huggingface_hub import hf_hub_download
class TaskFileTool:
    """Fetch and read auxiliary files attached to GAIA benchmark tasks.

    Resolution strategy for a task's file:
      1. Query the scorer API (``{api_base_url}/files/{id}``, then ``/file/{id}``).
      2. Fall back to the ``gaia-benchmark/GAIA`` dataset on the Hugging Face
         Hub, mapping ``task_id`` to the repo-relative ``file_path`` recorded
         in the validation split.

    Downloaded files are cached under ``cache_dir`` and reused on later calls.
    """

    def __init__(self, api_base_url: str, cache_dir: str = "task_files", timeout: int = 30):
        """
        Args:
            api_base_url: Base URL of the scorer API; a trailing slash is stripped.
            cache_dir: Directory used to cache downloaded files (created if absent).
            timeout: Per-request timeout, in seconds, for scorer API calls.
        """
        self.api_base_url = api_base_url.rstrip("/")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.timeout = timeout
        # None = index not built yet; {} = build already attempted (even if it failed).
        self._task_to_file_path: dict[str, str] | None = None

    def _ensure_gaia_index(self) -> None:
        """Build the task_id -> file_path index from the GAIA validation split, once."""
        if self._task_to_file_path is not None:
            return
        print("[TaskFileTool] Loading GAIA validation split for file-path index...")
        self._task_to_file_path = {}
        try:
            ds = load_dataset(
                "gaia-benchmark/GAIA",
                "2023_level1",
                split="validation",
            )
            for ex in ds:
                ex_dict = dict(ex)
                task_id = str(ex_dict.get("task_id", "")).strip()
                file_path = str(ex_dict.get("file_path", "")).strip()
                if task_id and file_path:
                    self._task_to_file_path[task_id] = file_path
            print(f"[TaskFileTool] Indexed {len(self._task_to_file_path)} task_id -> file_path rows")
        except Exception as e:
            # Best effort: keep an empty index so we do not retry the download
            # on every subsequent call.
            print(f"[TaskFileTool] Failed to build GAIA index: {e}")
            self._task_to_file_path = {}

    def get_task_context(self, task_id: str) -> str:
        """Return the task's attached file rendered as text, or "" if unavailable."""
        file_path = self.download_task_file(task_id=task_id)
        if file_path is None:
            return ""
        return self.read_file_as_text(file_path)

    def download_task_file(
        self,
        task_id: Optional[str] = None,
        file_name: Optional[str] = None,
    ) -> Optional[Path]:
        """
        Download the file associated with a task into the local cache.

        Strategy:
        1. Try scorer endpoints
        2. Fallback to GAIA dataset file_path via hf_hub_download

        Args:
            task_id: Benchmark task identifier; used for both API and dataset lookup.
            file_name: Optional explicit file name to try against the cache / API.

        Returns:
            Path to the locally cached file, or None if every strategy failed.
        """
        # 0) If file_name already exists in cache, use it immediately
        if file_name:
            local_candidate = self.cache_dir / Path(str(file_name)).name
            if local_candidate.exists():
                print(f"[download_task_file] using cached file -> {local_candidate}")
                return local_candidate
        # 1) Try scorer API first. Candidates are tried in order: the raw
        # file_name, its basename, then the task_id; `seen` de-duplicates.
        candidates: list[str] = []
        if file_name:
            candidates.append(str(file_name))
            candidates.append(Path(str(file_name)).name)
        if task_id:
            candidates.append(str(task_id))
        seen: set[str] = set()
        for ident in candidates:
            ident = str(ident).strip()
            if not ident or ident in seen:
                continue
            seen.add(ident)
            # Both historical endpoint spellings are probed.
            urls = [
                f"{self.api_base_url}/files/{quote(ident)}",
                f"{self.api_base_url}/file/{quote(ident)}",
            ]
            for url in urls:
                print(f"[download_task_file] trying {url}")
                try:
                    response = requests.get(url, timeout=self.timeout)
                    print(
                        f"[download_task_file] status={response.status_code} "
                        f"url={url} "
                        f"content_type={response.headers.get('content-type')} "
                        f"content_disposition={response.headers.get('content-disposition')}"
                    )
                except requests.RequestException as e:
                    print(f"[download_task_file] request error for {url}: {e}")
                    continue
                if response.status_code != 200:
                    preview = response.text[:200] if hasattr(response, "text") else ""
                    print(f"[download_task_file] failed {url} body={preview}")
                    continue
                filename = self._infer_filename(response=response, fallback_name=ident)
                file_path = self.cache_dir / filename
                try:
                    with open(file_path, "wb") as f:
                        f.write(response.content)
                    print(f"[download_task_file] saved -> {file_path}")
                    return file_path
                except OSError as e:
                    # FIX: a local write failure must not abort the whole lookup.
                    # Previously this returned None, silently skipping both the
                    # remaining URLs and the GAIA dataset fallback below.
                    print(f"[download_task_file] save error: {e}")
                    continue
        # 2) Fallback: GAIA dataset repo by task_id -> file_path
        if task_id:
            self._ensure_gaia_index()
            repo_rel_path = (self._task_to_file_path or {}).get(str(task_id), "")
            if repo_rel_path:
                print(f"[download_task_file] fallback via GAIA dataset file_path={repo_rel_path}")
                try:
                    cached_path = hf_hub_download(
                        repo_id="gaia-benchmark/GAIA",
                        repo_type="dataset",
                        filename=repo_rel_path,
                    )
                    # Copy out of the HF cache into our own cache dir so later
                    # calls hit the fast path in step 0.
                    src = Path(cached_path)
                    dst = self.cache_dir / src.name
                    if not dst.exists():
                        dst.write_bytes(src.read_bytes())
                    print(f"[download_task_file] dataset fallback saved -> {dst}")
                    return dst
                except Exception as e:
                    print(f"[download_task_file] dataset fallback ERROR: {e}")
        return None

    def read_file_as_text(self, file_path: Path) -> str:
        """Render a downloaded file as text for prompt context.

        Text-like files are returned (pretty-printed for JSON), spreadsheets
        get a per-sheet CSV preview, extensionless files are read best-effort,
        and anything else (binary formats) yields "".
        """
        suffix = file_path.suffix.lower()
        try:
            if suffix in {".txt", ".md", ".html", ".xml", ".csv", ".json", ".py"}:
                return self._read_supported_text_file(file_path, suffix)
            if suffix in {".xlsx", ".xls"}:
                return self._read_excel_preview(file_path)
            if suffix == "":
                return self._read_extensionless_file(file_path)
            return ""
        except Exception:
            # Reading is best-effort; an unreadable file is treated as "no context".
            return ""

    def _read_supported_text_file(self, file_path: Path, suffix: str) -> str:
        """Read a known text format; JSON is pretty-printed, CSV round-tripped via pandas."""
        if suffix in {".txt", ".md", ".html", ".xml", ".py"}:
            return file_path.read_text(encoding="utf-8", errors="ignore")
        if suffix == ".json":
            raw = file_path.read_text(encoding="utf-8", errors="ignore")
            try:
                parsed = json.loads(raw)
                return json.dumps(parsed, indent=2, ensure_ascii=False)
            except json.JSONDecodeError:
                # Not valid JSON: fall back to the raw text.
                return raw
        if suffix == ".csv":
            try:
                df = pd.read_csv(file_path)
                return df.to_csv(index=False)
            except Exception:
                # pandas parse failure (odd delimiters, encoding): return raw text.
                return file_path.read_text(encoding="utf-8", errors="ignore")
        return ""

    def _read_excel_preview(self, file_path: Path) -> str:
        """Return a CSV preview (first 20 rows) of up to 5 sheets of a workbook."""
        try:
            xls = pd.ExcelFile(file_path)
            chunks: list[str] = []
            for sheet_name in xls.sheet_names[:5]:
                try:
                    df = pd.read_excel(file_path, sheet_name=sheet_name)
                    chunks.append(f"Sheet: {sheet_name}")
                    chunks.append(df.head(20).to_csv(index=False))
                except Exception:
                    # Skip unreadable sheets but keep the rest of the preview.
                    continue
            return "\n\n".join(chunks).strip()
        except Exception:
            return ""

    def _read_extensionless_file(self, file_path: Path) -> str:
        """Best-effort: treat an extensionless file as UTF-8 text if non-empty."""
        try:
            raw = file_path.read_text(encoding="utf-8", errors="ignore")
            if raw.strip():
                return raw
        except Exception:
            pass
        return ""

    def _infer_filename(self, response: requests.Response, fallback_name: str) -> str:
        """Choose a safe local filename for a downloaded response.

        Preference order: the Content-Disposition filename; otherwise the
        fallback name (kept as-is if it already has an extension, else with an
        extension guessed from the Content-Type header).
        """
        content_disposition = response.headers.get("content-disposition", "")
        filename = self._extract_filename_from_content_disposition(content_disposition)
        if filename:
            return self._safe_filename(filename)
        content_type = response.headers.get("content-type", "").lower()
        extension = self._extension_from_content_type(content_type)
        fallback_base = self._safe_filename(fallback_name)
        if Path(fallback_base).suffix:
            return fallback_base
        if extension:
            return f"{fallback_base}{extension}"
        return fallback_base

    @staticmethod
    def _extract_filename_from_content_disposition(content_disposition: str) -> Optional[str]:
        """Parse the ``filename`` parameter out of a Content-Disposition header.

        FIX: the previous implementation split on ``"filename="`` and took
        everything to the end of the header, so trailing parameters (e.g.
        ``; size=5``) leaked into the filename. Parse per-``;`` segment instead.
        """
        if "filename=" not in content_disposition:
            return None
        try:
            for segment in content_disposition.split(";"):
                key, _, value = segment.strip().partition("=")
                if key.strip().lower() == "filename":
                    filename = value.strip().strip('"')
                    return filename or None
            return None
        except Exception:
            return None

    @staticmethod
    def _extension_from_content_type(content_type: str) -> str:
        """Map a (possibly parameterized) MIME type to a file extension, or ""."""
        mapping = {
            "text/plain": ".txt",
            "text/csv": ".csv",
            "application/csv": ".csv",
            "application/json": ".json",
            "text/markdown": ".md",
            "text/html": ".html",
            "application/xml": ".xml",
            "text/xml": ".xml",
            "application/pdf": ".pdf",
            "image/png": ".png",
            "image/jpeg": ".jpg",
            "image/jpg": ".jpg",
            "image/webp": ".webp",
            "audio/mpeg": ".mp3",
            "audio/mp3": ".mp3",
            "audio/wav": ".wav",
            "audio/x-wav": ".wav",
            "audio/mp4": ".m4a",
            "audio/x-m4a": ".m4a",
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
            "application/vnd.ms-excel": ".xls",
            "text/x-python": ".py",
            "text/python": ".py",
        }
        # Substring match so "text/csv; charset=utf-8" still resolves.
        for key, ext in mapping.items():
            if key in content_type:
                return ext
        return ""

    @staticmethod
    def _safe_filename(filename: str) -> str:
        """Strip any directory components to prevent path traversal out of cache_dir."""
        return os.path.basename(filename)