# DatasetChecker / gemini_batch.py
# archivartaunik's picture
# Upload 8 files
# 45db9d9 verified
"""Gemini BATCH API helper."""
from __future__ import annotations
"""Helpers for running Gemini BATCH API jobs."""
from dataclasses import dataclass
import json
import mimetypes
import os
import tempfile
import time
from typing import Any, Dict, Iterable, List, Optional
import requests
from calls_analyser.domain.exceptions import AIModelError
try: # pragma: no cover - optional dependency wiring
from google import genai
from google.genai import types
except Exception: # pragma: no cover - optional dependency wiring
genai = None # type: ignore
types = None # type: ignore
@dataclass(frozen=False)
class BatchTask:
    """One audio file waiting to be sent through a Gemini BATCH job.

    Instances are intentionally mutable: the batch runner in this module
    writes the uploaded file's URI (and possibly a corrected MIME type)
    back onto the task after the upload step.
    """

    # Identifier used to map the batch response back to this task.
    key: str
    # Local filesystem path of the audio file to upload.
    path: str
    # MIME type sent alongside the file reference in the request.
    mime_type: str
    # URI of the uploaded file; stays ``None`` until the upload has happened.
    file_uri: Optional[str] = None
class GeminiBatchRunner:
    """Create and poll Gemini BATCH jobs for multiple audio files.

    For every chunk of tasks the runner uploads the audio files through the
    ``google-genai`` client, writes a JSONL request file, submits it through
    the REST ``batchGenerateContent`` endpoint, polls the job until it reaches
    a terminal state and finally parses the downloaded JSONL result file.
    """

    # Base URL shared by the REST "create batch" and "get batch" calls.
    _API_ROOT = "https://generativelanguage.googleapis.com/v1beta"
    # Per-request timeout (seconds) passed to ``requests``.
    _REQUEST_TIMEOUT_SECONDS = 60
    # The only state for which a result file is read.
    _SUCCESS_STATE = "BATCH_STATE_SUCCEEDED"
    # States that stop the polling loop.
    _TERMINAL_STATES = frozenset(
        {
            "BATCH_STATE_SUCCEEDED",
            "BATCH_STATE_FAILED",
            "BATCH_STATE_CANCELLED",
            "BATCH_STATE_EXPIRED",
            "BATCH_STATE_PAUSED",
        }
    )
    # Class-level defaults so the polling settings exist even when an
    # instance is created without running ``__init__``.
    _poll_interval = 30.0
    _poll_timeout: Optional[float] = None

    def __init__(
        self,
        api_key: str,
        model: str,
        *,
        poll_interval: float = 30.0,
        poll_timeout: Optional[float] = None,
    ) -> None:
        """Create a runner bound to one API key and one model.

        Args:
            api_key: Gemini API key, used for the SDK client and REST headers.
            model: Model identifier, with or without the ``models/`` prefix.
            poll_interval: Seconds to sleep between two status checks.
                Negative values are clamped to ``0``.
            poll_timeout: Maximum number of seconds to wait for a single
                batch job. ``None`` (the default) waits indefinitely.

        Raises:
            AIModelError: If the ``google-genai`` package could not be imported.
        """
        if genai is None:
            raise AIModelError("google-genai library is not available")
        self._api_key = api_key
        self._model = model
        self._poll_interval = max(0.0, float(poll_interval))
        self._poll_timeout = None if poll_timeout is None else float(poll_timeout)
        self._client = genai.Client(api_key=api_key)

    def run_batch(
        self,
        tasks: Iterable["BatchTask"],
        prompt_text: str,
        *,
        chunk_size: int = 20,
    ) -> Dict[str, str]:
        """Run batch jobs and return mapping ``key -> text``.

        Args:
            tasks: Audio files to process; consumed once into a list.
            prompt_text: Prompt placed before the audio part of every request.
            chunk_size: Number of tasks per batch job; values below ``1`` are
                treated as ``1``.

        Returns:
            A dict mapping each request key to the response text, or to an
            ``"Error: ..."`` string when the record carried an error. Records
            with neither text nor an error are omitted.
        """
        pending = list(tasks)
        if not pending:
            return {}

        results: Dict[str, str] = {}
        normalized_chunk_size = max(1, int(chunk_size))
        for start in range(0, len(pending), normalized_chunk_size):
            chunk = pending[start : start + normalized_chunk_size]
            self._process_chunk(chunk, start // normalized_chunk_size, prompt_text, results)
        return results

    def _process_chunk(
        self,
        chunk: List["BatchTask"],
        chunk_index: int,
        prompt_text: str,
        results: Dict[str, str],
    ) -> None:
        """Upload, submit, poll and collect results for one chunk of tasks.

        The tasks are mutated: ``file_uri`` is set from the upload response and
        ``mime_type`` is replaced when the upload response reports one.
        Uploaded audio files and the uploaded JSONL request file are deleted
        afterwards on a best-effort basis, whether or not the job succeeded.
        """
        if not chunk:
            return

        uploaded_file_names: List[str] = []
        uploaded_jsonl_name: Optional[str] = None
        try:
            for task in chunk:
                uploaded = self._client.files.upload(file=task.path)
                task.file_uri = uploaded.uri
                task.mime_type = uploaded.mime_type or task.mime_type
                uploaded_file_names.append(uploaded.name)

            # The local JSONL file is only needed until it has been uploaded.
            with tempfile.TemporaryDirectory() as tmpdir:
                jsonl_path = os.path.join(tmpdir, f"batch_input_{chunk_index:03}.jsonl")
                self._prepare_chunk_jsonl(chunk, jsonl_path, prompt_text, chunk_index)
                uploaded_jsonl = self._upload_jsonl(jsonl_path, f"batch-input-{chunk_index:03}")
                uploaded_jsonl_name = uploaded_jsonl.name

            batch_name = self._create_batch_job_rest(
                model_id=self._model,
                input_file_name=uploaded_jsonl_name,
                display_name=f"audio-batch-{chunk_index:03}",
            )
            dest_file_name = self._poll_batch_job(batch_name)
            file_content = self._client.files.download(file=dest_file_name)
            self._process_results_jsonl_bytes(file_content, results)
        finally:
            for name in uploaded_file_names:
                self._delete_remote_file_quietly(name)
            if uploaded_jsonl_name:
                self._delete_remote_file_quietly(uploaded_jsonl_name)

    def _delete_remote_file_quietly(self, name: str) -> None:
        """Delete an uploaded file, ignoring any failure (best-effort cleanup)."""
        try:
            self._client.files.delete(name=name)
        except Exception:  # pragma: no cover - cleanup best effort
            # Cleanup must never mask the real outcome of the batch job.
            pass

    # ------------------------- JSONL helpers -------------------------
    def _prepare_chunk_jsonl(
        self,
        tasks_chunk: List["BatchTask"],
        jsonl_path: str,
        prompt_text: str,
        chunk_index: int,
    ) -> None:
        """Write one JSONL request line per task to ``jsonl_path``.

        Each line holds the task key (or a generated
        ``chunkNNN_batch_NNN`` key when the task key is empty) and a single
        user turn built by :meth:`_build_parts_for_task`.
        """
        parent_dir = os.path.dirname(jsonl_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(jsonl_path, "w", encoding="utf-8") as f:
            for i, task in enumerate(tasks_chunk):
                unique_key = task.key or f"chunk{chunk_index:03}_batch_{i:03}"
                parts = self._build_parts_for_task(task, prompt_text)
                request_entry = {
                    "key": unique_key,
                    "request": {
                        "contents": [
                            {
                                "role": "user",
                                "parts": parts,
                            }
                        ]
                    },
                }
                f.write(json.dumps(request_entry, ensure_ascii=False) + "\n")

    @staticmethod
    def _build_parts_for_task(task: "BatchTask", prompt_text: str) -> List[Dict[str, Any]]:
        """Return the request parts: optional stripped prompt, then the file reference."""
        clean_prompt = (prompt_text or "").strip()
        parts: List[Dict[str, Any]] = []
        if clean_prompt:
            parts.append({"text": clean_prompt})
        parts.append(
            {
                "file_data": {
                    "mime_type": task.mime_type,
                    "file_uri": task.file_uri,
                }
            }
        )
        return parts

    def _upload_jsonl(self, jsonl_path: str, display_name: str):
        """Upload the JSONL request file and return the SDK upload response.

        The first attempt declares the ``"jsonl"`` MIME type; if that attempt
        raises for any reason, the upload is retried once without it.
        """
        try:
            return self._client.files.upload(
                file=jsonl_path,
                config=types.UploadFileConfig(display_name=display_name, mime_type="jsonl"),
            )
        except Exception:
            return self._client.files.upload(
                file=jsonl_path,
                config=types.UploadFileConfig(display_name=display_name),
            )

    # ------------------------- REST helpers --------------------------
    @staticmethod
    def _rest_model_name(model_id: str) -> str:
        """Return ``model_id`` without a leading ``models/`` prefix."""
        prefix = "models/"
        # Only strip the prefix; str.replace would also remove inner matches.
        if model_id.startswith(prefix):
            return model_id[len(prefix):]
        return model_id

    def _create_batch_job_rest(self, model_id: str, input_file_name: str, display_name: str) -> str:
        """Create a batch job over REST and return its resource name.

        Raises:
            AIModelError: If the HTTP status is not OK, or the response does
                not contain a batch name.
        """
        url = f"{self._API_ROOT}/models/{self._rest_model_name(model_id)}:batchGenerateContent"
        headers = {
            "x-goog-api-key": self._api_key,
            "Content-Type": "application/json",
        }
        payload = {
            "batch": {
                "display_name": display_name,
                "input_config": {"file_name": input_file_name},
            }
        }
        resp = requests.post(
            url, headers=headers, json=payload, timeout=self._REQUEST_TIMEOUT_SECONDS
        )
        if not resp.ok:
            raise AIModelError(f"REST create failed: {resp.status_code} {resp.text}")

        data = resp.json()
        name = data.get("name")
        # Some responses nest the job under a "batch" object.
        if not name and isinstance(data.get("batch"), dict):
            name = data["batch"].get("name")
        if not name:
            raise AIModelError(f"REST create succeeded but no batch name found. Response: {data}")
        return name

    def _get_batch_job_rest(self, name: str) -> Dict[str, Any]:
        """Fetch the current REST representation of the batch job ``name``.

        Raises:
            AIModelError: If the HTTP status is not OK.
        """
        url = f"{self._API_ROOT}/{name}"
        headers = {"x-goog-api-key": self._api_key}
        resp = requests.get(url, headers=headers, timeout=self._REQUEST_TIMEOUT_SECONDS)
        if not resp.ok:
            raise AIModelError(f"REST get failed: {resp.status_code} {resp.text}")
        return resp.json()

    @staticmethod
    def _extract_state(rest_obj: Dict[str, Any]) -> Optional[str]:
        """Return the job state from the top level, ``metadata`` or ``batch``."""
        return (
            rest_obj.get("state")
            or (rest_obj.get("metadata") or {}).get("state")
            or (rest_obj.get("batch") or {}).get("state")
        )

    @staticmethod
    def _extract_result_file_name(rest_obj: Dict[str, Any]) -> Optional[str]:
        """Return the result file name from ``response``, trying known spellings."""
        resp = rest_obj.get("response") or {}
        dest = resp.get("dest") or {}
        return (
            dest.get("file_name")
            or dest.get("fileName")
            or resp.get("file_name")
            or resp.get("fileName")
            or resp.get("responsesFile")
            or resp.get("responses_file")
        )

    def _poll_batch_job(self, batch_name: str) -> str:
        """Poll the job until it reaches a terminal state; return the result file name.

        Raises:
            AIModelError: If the configured poll timeout elapses, the job ends
                in a state other than ``BATCH_STATE_SUCCEEDED``, or no result
                file name can be found in the final response.
        """
        timeout = self._poll_timeout
        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            rest_job = self._get_batch_job_rest(batch_name)
            state = self._extract_state(rest_job)
            if state in self._TERMINAL_STATES:
                break
            if deadline is not None and time.monotonic() >= deadline:
                raise AIModelError(
                    f"Timed out after {timeout} seconds waiting for batch job "
                    f"{batch_name} (last state: {state})"
                )
            time.sleep(self._poll_interval)

        if state != self._SUCCESS_STATE:
            err = rest_job.get("error") or (rest_job.get("response") or {}).get("error")
            raise AIModelError(f"Batch job failed with state {state}: {err}")

        result_file_name = self._extract_result_file_name(rest_job)
        if not result_file_name:
            raise AIModelError("Could not locate result file name in REST response")
        return result_file_name

    # ------------------------- Results processing --------------------
    @staticmethod
    def _first_candidate_text(candidates: Any) -> Optional[str]:
        """Return the first non-empty text part of the first candidate, if any."""
        if not isinstance(candidates, list) or not candidates:
            return None
        first = candidates[0]
        if not isinstance(first, dict):
            return None
        content = first.get("content")
        if not isinstance(content, dict):
            return None
        parts = content.get("parts")
        if not isinstance(parts, list):
            return None
        for part in parts:
            if isinstance(part, dict) and part.get("text"):
                return part["text"]
        return None

    @staticmethod
    def _process_results_jsonl_bytes(content_bytes: bytes, results: Dict[str, str]) -> None:
        """Parse a JSONL result file and store ``key -> text`` into ``results``.

        Blank lines, lines that are not valid JSON objects and records without
        a key are skipped. A record whose ``response`` contains ``error`` is
        stored as ``"Error: <error>"``. A record with no usable text but a
        top-level ``error`` is stored the same way. Records with neither text
        nor an error are left out of ``results``.
        """
        content_str = content_bytes.decode("utf-8", errors="replace")
        for line in content_str.splitlines():
            if not line.strip():
                continue
            try:
                result = json.loads(line)
            except (ValueError, RecursionError):
                continue
            # Valid JSON that is not an object cannot carry a key.
            if not isinstance(result, dict):
                continue

            key = result.get("key")
            if not key:
                continue

            response_wrapper = result.get("response")
            # "response" may be missing or null on failed records.
            if not isinstance(response_wrapper, dict):
                response_wrapper = {}

            if "error" in response_wrapper:
                results[key] = f"Error: {response_wrapper['error']}"
                continue

            text = GeminiBatchRunner._first_candidate_text(response_wrapper.get("candidates"))
            if text is not None:
                results[key] = text
            elif result.get("error"):
                # Surface record-level failures instead of dropping them.
                results[key] = f"Error: {result['error']}"
def guess_mime_type(path: str) -> str:
    """Return the MIME type guessed from ``path``.

    Falls back to ``"application/octet-stream"`` when ``mimetypes`` cannot
    determine a type for the given path.
    """
    guessed = mimetypes.guess_type(path)[0]
    if not guessed:
        return "application/octet-stream"
    return guessed