# gaia_unit4_space/tools/media_tools.py
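"""Media tools for the GAIA agent: speech-to-text and visual question answering.

Each tool tries providers in order: Groq's OpenAI-compatible API, then the
OpenAI platform, then the Hugging Face Inference API. Missing credentials and
request failures are reported as returned strings rather than raised
exceptions.
"""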
import base64
import os
from pathlib import Path
from typing import Any, Optional
from inference_client_factory import make_inference_client


def _groq_openai_client():
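    """Return an OpenAI-compatible client for Groq, or None if GROQ_API_KEY is unset."""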
k = os.environ.get("GROQ_API_KEY", "").strip()
if not k:
return None
from openai import OpenAI
    return OpenAI(api_key=k, base_url="https://api.groq.com/openai/v1")


def _openai_platform_client():
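    """Return an OpenAI platform client, or None if OPENAI_API_KEY is unset."""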
k = os.environ.get("OPENAI_API_KEY", "").strip()
if not k:
return None
from openai import OpenAI
base = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1").strip()
    return OpenAI(api_key=k, base_url=base)


def transcribe_audio(
file_path: str,
*,
hf_token: Optional[str] = None,
model: Optional[str] = None,
) -> str:
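    """Transcribe an audio file to text.

    Providers are tried in order: Groq (whisper-large-v3), then the OpenAI
    platform (whisper-1), then the HF Inference API. Model ids can be
    overridden via the GAIA_*_ASR_MODEL environment variables or the `model`
    argument. Failures are returned as error strings rather than raised.
    """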
gc = _groq_openai_client()
if gc:
mid = model or os.environ.get("GAIA_GROQ_ASR_MODEL", "whisper-large-v3")
try:
with open(file_path, "rb") as audio_f:
tr = gc.audio.transcriptions.create(
model=mid,
file=audio_f,
)
return (tr.text or "").strip()
except Exception as e:
return f"ASR error (Groq): {e}"
oc = _openai_platform_client()
if oc:
mid = model or os.environ.get("GAIA_OPENAI_ASR_MODEL", "whisper-1")
try:
with open(file_path, "rb") as audio_f:
tr = oc.audio.transcriptions.create(model=mid, file=audio_f)
return (tr.text or "").strip()
except Exception as e:
return f"ASR error (OpenAI): {e}"
    token = (
        hf_token
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    )
if not token:
return (
"Error: set GROQ_API_KEY (free), OPENAI_API_KEY, or HF_TOKEN for speech."
)
mid = model or os.environ.get("GAIA_ASR_MODEL", "openai/whisper-large-v3")
client = make_inference_client(token)
try:
out = client.automatic_speech_recognition(file_path, model=mid)
return (out.text or "").strip()
except Exception as e:
return f"ASR error: {e}"
def _vision_chat_openai(
client: Any,
*,
model: str,
file_path: Path,
question: str,
) -> str:
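    """Send one user message (question plus image as a base64 data URL) through
    an OpenAI-style chat completions endpoint and return the reply text.
    """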
    data_url = _image_data_url(file_path)
comp = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": question},
{
"type": "image_url",
"image_url": {"url": data_url},
},
],
}
],
max_tokens=512,
temperature=0.2,
)
    return (comp.choices[0].message.content or "").strip()


def analyze_image_with_vlm(
file_path: str,
question: str,
*,
hf_token: Optional[str] = None,
model: Optional[str] = None,
) -> str:
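    """Answer a free-form question about an image with a vision-language model.

    Tries Groq first, then the OpenAI platform, then an HF-hosted VLM; model
    ids are overridable via the GAIA_*_VISION_MODEL environment variables or
    the `model` argument.
    """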
path = Path(file_path)
if not path.is_file():
return f"Error: image not found: {file_path}"
gc = _groq_openai_client()
if gc:
mid = model or os.environ.get(
"GAIA_GROQ_VISION_MODEL",
"llama-3.2-11b-vision-preview",
)
try:
return _vision_chat_openai(gc, model=mid, file_path=path, question=question)
except Exception as e:
return f"Vision error (Groq): {e}"
oc = _openai_platform_client()
if oc:
mid = model or os.environ.get("GAIA_OPENAI_VISION_MODEL", "gpt-4o-mini")
try:
return _vision_chat_openai(oc, model=mid, file_path=path, question=question)
except Exception as e:
return f"Vision error (OpenAI): {e}"
    token = (
        hf_token
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    )
if not token:
return (
"Error: set GROQ_API_KEY, OPENAI_API_KEY, or HF_TOKEN for vision."
)
mid = model or os.environ.get(
"GAIA_VISION_MODEL", "meta-llama/Llama-3.2-11B-Vision-Instruct"
)
    data_url = _image_data_url(path)
client = make_inference_client(token)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": question},
{"type": "image_url", "image_url": {"url": data_url}},
],
}
]
try:
comp = client.chat_completion(
messages=messages,
model=mid,
max_tokens=512,
temperature=0.2,
)
msg = comp.choices[0].message
return (msg.content or "").strip()
except Exception as e:
return f"Vision error: {e}"
def visual_question_short(
file_path: str,
question: str,
*,
hf_token: Optional[str] = None,
model: Optional[str] = None,
) -> str:
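    """Short-answer visual question answering.

    Routes to the chat-VLM path when a Groq or OpenAI key is available;
    otherwise falls back to a BLIP VQA model on HF and returns the top-5
    answers with confidence scores.
    """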
if _groq_openai_client() or _openai_platform_client():
return analyze_image_with_vlm(
file_path, question, hf_token=hf_token, model=model
)
    token = (
        hf_token
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    )
if not token:
return "Error: HF_TOKEN not set for VQA."
mid = model or "Salesforce/blip-vqa-base"
client = make_inference_client(token)
try:
answers = client.visual_question_answering(
image=file_path, question=question, model=mid, top_k=5
)
lines = [f"{a.answer} ({a.score:.3f})" for a in answers]
return "\n".join(lines)
except Exception as e:
return f"VQA error: {e}"