Final_Assignment_Template / tools /attachments.py
jwlee-ai's picture
Upload folder using huggingface_hub
4a5f5e9 verified
"""GAIA 첨뢀 파일 처리 + μ§ˆλ¬Έβ†”task_id 인덱슀.

Handles GAIA attachments and maintains the question <-> task_id index.

CodeAgent's call signature (__call__ receives only the question) leaves no way
to inject the task_id directly, so this module works around it with a
module-level mutable container plus a prefetched index.

Flow:
1) During BasicAgent.__init__, prefetch_question_index() calls /questions once,
   builds a {question text: task_id} dict, and registers it through
   set_question_index().
2) On entry to BasicAgent.__call__, set_current_task(question) stores the
   current question's task_id and text in _CURRENT_TASK.
3) When the agent calls get_attached_file() with no arguments, the file is
   downloaded from the scoring server using the task_id in _CURRENT_TASK and
   handled by type:
   - Text/CSV/JSON/code: UTF-8 decoding
   - Excel (.xlsx): one CSV per sheet
   - PDF: per-page text extraction (pypdf)
   - Image: VLM (Qwen2.5-VL-7B) analysis tailored to the current question
   - Audio: Whisper (large-v3) transcription
"""
import io
import re
import requests
from smolagents import tool
# Scoring server URL, defined here as well (same value as in app.py) so that
# the tools module still makes sense when it is used on its own.
_DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Mutable container refreshed on entry to BasicAgent.__call__.
# "question" is used as the prompt context when the image VLM is called.
_CURRENT_TASK = {"id": None, "question": None}
# Maps question.strip() -> task_id.
_QUESTION_INDEX: dict = {}
def prefetch_question_index() -> dict:
    """Fetch /questions from the scoring server once and build a question index.

    Entries that are not dicts, whose question is not a string, or that lack a
    question text or task_id are skipped individually instead of aborting the
    whole index.

    Returns:
        A dict mapping each stripped question text to its task_id. An empty
        dict is returned when the request or JSON decoding fails, so the agent
        can still attempt questions that have no attachment.
    """
    # Deliberately broad: this is a best-effort boundary and must never raise.
    try:
        r = requests.get(f"{_DEFAULT_API_URL}/questions", timeout=15)
        r.raise_for_status()
        payload = r.json()
    except Exception as e:
        print(f"Warning: could not prefetch question index: {e}")
        return {}

    if not isinstance(payload, list):
        print(
            "Warning: could not prefetch question index: "
            f"unexpected payload type {type(payload).__name__}"
        )
        return {}

    idx = {}
    for item in payload:
        if not isinstance(item, dict):
            continue
        question = item.get("question")
        tid = item.get("task_id")
        if not isinstance(question, str):
            continue
        qt = question.strip()
        if not (qt and tid):
            continue
        # Identical question text with a different task_id: last one wins,
        # but make the collision visible.
        if qt in idx and idx[qt] != tid:
            print(
                "Warning: duplicate question text in prefetch index — "
                f"task_id {idx[qt]!r} will be overwritten by {tid!r}"
            )
        idx[qt] = tid
    return idx
def set_question_index(idx: dict) -> None:
    """Install a prefetched question index as the module-level lookup table.

    Meant to be called from BasicAgent.__init__ with the result of
    prefetch_question_index().

    Args:
        idx: Mapping of stripped question text to task_id. ``None`` is treated
            as an empty index so later lookups never hit a non-dict.
    """
    global _QUESTION_INDEX
    # Keep the caller's dict object (no copy) so aliasing behaves as before;
    # only a missing index is replaced.
    _QUESTION_INDEX = {} if idx is None else idx
def set_current_task(question: str):
    """Record the current question and its task_id in module-level state.

    Meant to be called on entry to BasicAgent.__call__. The stored question
    text is later read by the image handler as prompt context.

    Args:
        question: The question text. It is matched against the index after
            ``strip()``. A non-string value (e.g. ``None``) matches nothing.

    Returns:
        The matching task_id, or ``None`` when the question is not in the
        index. The question itself is stored unchanged either way.
    """
    # Never raise before the state is overwritten: a stale id from the
    # previous task would make get_attached_file() fetch the wrong file.
    key = question.strip() if isinstance(question, str) else ""
    tid = _QUESTION_INDEX.get(key)
    _CURRENT_TASK["id"] = tid
    _CURRENT_TASK["question"] = question
    return tid
# --- File-type dispatch helpers ---
def _extract_filename(headers, url: str) -> str:
"""Content-Disposition ν—€λ”μ—μ„œ filename을 λ½‘κ±°λ‚˜, URL λλΆ€λΆ„μœΌλ‘œ 폴백.
채점 μ„œλ²„κ°€ Content-Type을 octet-stream으둜 쀄 λ•Œ ν™•μž₯자둜 λ³΄κ°•ν•˜κΈ° μœ„ν•¨."""
cd = headers.get("Content-Disposition", "")
# filename* (RFC 5987) 와 filename= μ–‘μͺ½ λ‹€ 처리.
m = re.search(r'filename\*?=(?:UTF-8\'\')?"?([^";\r\n]+)"?', cd)
if m:
return m.group(1).strip().strip('"')
return url.rsplit("/", 1)[-1]
def _is_excel(content_type: str, ext: str) -> bool:
if ext in ("xlsx", "xls"):
return True
ct = content_type.lower()
return "spreadsheet" in ct or ct.endswith("xlsx") or ct.endswith("xls") or "excel" in ct
def _is_pdf(content_type: str, ext: str) -> bool:
return ext == "pdf" or "pdf" in content_type.lower()
def _is_image(content_type: str, ext: str) -> bool:
return ext in ("png", "jpg", "jpeg", "webp", "gif", "bmp") \
or content_type.lower().startswith("image/")
def _is_audio(content_type: str, ext: str) -> bool:
return ext in ("mp3", "wav", "m4a", "ogg", "flac") \
or content_type.lower().startswith("audio/")
# --- Per-type handlers ---
def _handle_excel(content: bytes, content_type: str) -> str:
"""xlsx β†’ μ‹œνŠΈλ³„ CSV둜 직렬화. GAIA에 맀좜/판맀 데이터 λ¬Έμ œκ°€ 자주 λ‚˜μ˜¨λ‹€."""
try:
import pandas as _pd
bio = io.BytesIO(content)
sheets = _pd.read_excel(bio, sheet_name=None)
parts = []
for name, df in sheets.items():
parts.append(f"--- Sheet: {name} ---\n{df.to_csv(index=False)}")
combined = "\n\n".join(parts)
if len(combined) > 12000:
combined = combined[:12000] + "\n...[truncated]"
return f"[Content-Type: {content_type}]\n{combined}"
except Exception as e:
return f"Excel parse error: {e}"
def _handle_pdf(content: bytes, content_type: str) -> str:
"""pypdf둜 PDF λ³Έλ¬Έ ν…μŠ€νŠΈ μΆ”μΆœ. νŽ˜μ΄μ§€λ³„λ‘œ κ΅¬λΆ„ν•΄μ„œ λ°˜ν™˜.
μŠ€μΊ” PDF(μ΄λ―Έμ§€λ‘œ 된)λŠ” ν…μŠ€νŠΈκ°€ λΉ„κ±°λ‚˜ 깨질 수 μžˆλŠ”λ°, κ·Έ κ²½μš°λŠ”
LLM이 μœ„ν‚€/μ›Ήκ²€μƒ‰μœΌλ‘œ ν΄λ°±ν•˜λ„λ‘ μ‹œμŠ€ν…œ ν”„λ‘¬ν”„νŠΈκ°€ μœ λ„ν•œλ‹€."""
try:
from pypdf import PdfReader
bio = io.BytesIO(content)
reader = PdfReader(bio)
parts = []
for i, page in enumerate(reader.pages):
try:
txt = page.extract_text() or ""
except Exception as pe:
txt = f"(extraction failed: {pe})"
parts.append(f"--- Page {i+1} ---\n{txt}")
combined = "\n\n".join(parts)
if len(combined) > 12000:
combined = combined[:12000] + "\n...[truncated]"
return f"[PDF, {len(reader.pages)} pages, Content-Type: {content_type}]\n{combined}"
except Exception as e:
return f"PDF parse error: {e}"
def _handle_image(content: bytes, content_type: str) -> str:
    """Analyze an image with a VLM (Qwen2.5-VL-7B) in the context of the question.

    The image is sent as a base64 data URL through the OpenAI-compatible
    chat_completion call of the HF Inference API. When a current question is
    available it is embedded in the prompt so the model extracts what matters
    for the answer (a generic caption tends to miss details).

    Per the original notes, the HF_TOKEN environment variable is required; on
    a Space deployment it must be registered as a Space secret.

    Returns:
        The model's analysis under an "[Image analysis ...]" header, or an
        "Image attached ... VLM analysis failed: ..." message on any failure
        so the agent can fall back to another strategy.
    """
    try:
        import base64
        from huggingface_hub import InferenceClient

        task_question = (_CURRENT_TASK.get("question") or "").strip()

        # Build the data URL. The declared type may not be image/*, so fall
        # back to a safe default.
        declared = content_type.partition(";")[0].strip()
        mime = declared if declared.startswith("image/") else "image/png"
        encoded = base64.b64encode(content).decode("utf-8")
        data_url = f"data:{mime};base64,{encoded}"

        if not task_question:
            prompt = (
                "Describe the attached image in detail, including any visible text, "
                "numbers, or labels."
            )
        else:
            prompt = (
                "Analyze the attached image and answer the following question. "
                "Read any text, numbers, or labels visible in the image. "
                "If it is a chart or table, extract the relevant data values precisely.\n\n"
                f"Question: {task_question}"
            )

        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }
        client = InferenceClient(provider="auto")  # uses the HF_TOKEN env var
        resp = client.chat_completion(
            model="Qwen/Qwen2.5-VL-7B-Instruct",
            messages=[user_message],
            max_tokens=1024,
        )
        analysis = resp.choices[0].message.content
        return (
            f"[Image analysis (Content-Type: {content_type}, {len(content)} bytes)]\n"
            f"{analysis}"
        )
    except Exception as err:
        return (
            f"Image attached (Content-Type: {content_type}, {len(content)} bytes). "
            f"VLM analysis failed: {err}"
        )
def _handle_audio(content: bytes, content_type: str) -> str:
    """Transcribe audio with Whisper (large-v3) in a single call.

    GAIA audio is usually a short utterance, so one request is enough (per
    the original notes). The HF_TOKEN environment variable is required; on a
    Space deployment it must be registered as a Space secret.

    Returns:
        The transcription under an "[Audio transcription ...]" header, or an
        "Audio attached ... Transcription failed: ..." message on any failure.
    """
    try:
        from huggingface_hub import InferenceClient

        asr_output = InferenceClient(provider="auto").automatic_speech_recognition(
            audio=content,
            model="openai/whisper-large-v3",
        )
        # Depending on the huggingface_hub version the result is either a
        # dataclass-like object or a dict, so both shapes are handled.
        if hasattr(asr_output, "text"):
            transcript = asr_output.text
        elif isinstance(asr_output, dict):
            transcript = asr_output.get("text", str(asr_output))
        else:
            transcript = str(asr_output)
        return (
            f"[Audio transcription (Content-Type: {content_type}, {len(content)} bytes)]\n"
            f"{transcript}"
        )
    except Exception as err:
        return (
            f"Audio attached (Content-Type: {content_type}, {len(content)} bytes). "
            f"Transcription failed: {err}"
        )
@tool
def get_attached_file() -> str:
    """Download the file attached to the CURRENT GAIA task and return its content.
    Takes no arguments — the current task_id is auto-resolved from the question.
    Use this whenever the question references a file, spreadsheet, image, audio, PDF, code listing,
    CSV, or any external resource. Returns:
    - Text/CSV/JSON/code: the decoded text (truncated to ~12k chars).
    - Excel (.xlsx): each sheet rendered as CSV (truncated).
    - PDF: extracted text per page (truncated).
    - Image (PNG/JPEG/WEBP/GIF/BMP): a vision-language model analysis focused on the current question.
    - Audio (MP3/WAV/M4A/OGG/FLAC): a Whisper transcription.
    - Other binary: a metadata description (size + content-type).
    """
    # NOTE(review): the docstring above is presumably what smolagents' @tool
    # exposes to the agent as the tool description, so its wording is kept
    # stable; implementation notes live in comments instead.
    #
    # The tool signature cannot take a task_id, so it is read from the
    # module-level _CURRENT_TASK, which set_current_task() fills in on entry
    # to BasicAgent.__call__.
    task_id = _CURRENT_TASK.get("id")
    if not task_id:
        return "No task context available — likely no file attached for this question."
    # Every failure is reported as a string so the agent can react to it.
    try:
        url = f"{_DEFAULT_API_URL}/files/{task_id}"
        r = requests.get(url, timeout=30)
        if r.status_code == 404:
            return "No file attached to this task."
        r.raise_for_status()
        content_type = r.headers.get("Content-Type", "")
        filename = _extract_filename(r.headers, url)
        ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

        # 1) Clearly binary/structured types are handled first. Some PDFs/SVGs
        #    do decode as UTF-8, but returning their raw text is far less
        #    useful than the dedicated handler's output.
        dispatch = (
            (_is_excel, _handle_excel),
            (_is_pdf, _handle_pdf),
            (_is_image, _handle_image),
            (_is_audio, _handle_audio),
        )
        for matches, handler in dispatch:
            if matches(content_type, ext):
                return handler(r.content, content_type)

        # 2) Text-like content is returned decoded. utf-8-sig is plain UTF-8
        #    that also drops a leading byte-order mark, which would otherwise
        #    end up glued to the first CSV header or JSON token.
        try:
            text = r.content.decode("utf-8-sig")
        except UnicodeDecodeError:
            # 3) Unknown binary: return metadata only.
            return (
                f"Binary file (Content-Type: {content_type}, "
                f"size: {len(r.content)} bytes). Cannot display as text. "
                f"URL: {url}"
            )
        if len(text) > 12000:
            text = text[:12000] + "\n...[truncated]"
        return f"[Content-Type: {content_type}]\n{text}"
    except Exception as e:
        return f"get_attached_file error: {e}"