# FATHOM-Hero: agents/shared/llm_client.py
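"""Structured-output LLM clients shared by the FATHOM-Hero agents.

Provides a Gemini API client and a local Hugging Face client, both of which
return responses validated against a caller-supplied Pydantic model.
"""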
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, Protocol, TypeVar
from dotenv import load_dotenv
from google import genai
from google.genai import types
from pydantic import BaseModel
from .model_schema import ModelMessage
try:
from trl.chat_template_utils import qwen3_chat_template
except Exception: # pragma: no cover - optional runtime dependency
qwen3_chat_template = None # type: ignore[assignment]
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel)
DEFAULT_GEMINI_DM_MODEL = "gemini-2.5-flash"
DEFAULT_GEMINI_HERO_MODEL = "gemini-2.5-flash"
DEFAULT_HF_DM_MODEL = "Qwen/Qwen3-32B"
DEFAULT_HF_HERO_MODEL = "Qwen/Qwen3-32B"
PROVIDER_GEMINI = "gemini"
PROVIDER_HF_LOCAL = "hf_local"
class StructuredModelClient(Protocol):
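    """Protocol for clients that turn chat messages into a validated Pydantic instance."""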
def generate_structured(
self,
messages: list[ModelMessage],
response_model: type[ResponseModelT],
*,
model_name: str,
temperature: float,
max_output_tokens: int,
) -> ResponseModelT:
...
class GeminiStructuredClient:
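    """Gemini client that tries progressively looser strategies to obtain structured JSON.

    Order of attempts: schema-constrained decoding via ``response_schema``, plain JSON mode,
    and finally a prompt-only request whose text is scanned for a JSON object.
    """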
def __init__(self, api_key: str | None = None) -> None:
self._client = self._create_client(api_key)
def generate_structured(
self,
messages: list[ModelMessage],
response_model: type[ResponseModelT],
*,
model_name: str,
temperature: float,
max_output_tokens: int,
) -> ResponseModelT:
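        """Try each generation strategy in order, raising with all failure reasons if none succeeds."""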
failures: list[str] = []
strategies = (
self._generate_with_response_schema,
self._generate_with_json_mode,
self._generate_with_prompt_only,
)
for strategy in strategies:
try:
return strategy(
messages,
response_model,
model_name=model_name,
temperature=temperature,
max_output_tokens=max_output_tokens,
)
except Exception as exc:
failures.append(f"{strategy.__name__}: {self._normalize_error(exc)}")
raise RuntimeError("Gemini structured generation failed. " + " | ".join(failures))
def _generate_with_response_schema(
self,
messages: list[ModelMessage],
response_model: type[ResponseModelT],
*,
model_name: str,
temperature: float,
max_output_tokens: int,
) -> ResponseModelT:
system_instruction, contents = self._split_messages(messages)
response = self._client.models.generate_content(
model=model_name,
contents=contents,
config=types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
max_output_tokens=max_output_tokens,
response_mime_type="application/json",
response_schema=response_model,
candidate_count=1,
),
)
parsed = getattr(response, "parsed", None)
if parsed is not None:
return response_model.model_validate(parsed)
text = getattr(response, "text", None)
if isinstance(text, str) and text.strip():
return response_model.model_validate_json(text)
raise RuntimeError("Gemini returned an empty structured response.")
def _generate_with_json_mode(
self,
messages: list[ModelMessage],
response_model: type[ResponseModelT],
*,
model_name: str,
temperature: float,
max_output_tokens: int,
) -> ResponseModelT:
prompt = self._json_prompt(messages, response_model)
response = self._client.models.generate_content(
model=model_name,
contents=prompt,
config=types.GenerateContentConfig(
temperature=temperature,
max_output_tokens=max_output_tokens,
response_mime_type="application/json",
candidate_count=1,
),
)
text = getattr(response, "text", None)
if not isinstance(text, str) or not text.strip():
raise RuntimeError("Gemini returned an empty JSON-mode response.")
return response_model.model_validate_json(text)
def _generate_with_prompt_only(
self,
messages: list[ModelMessage],
response_model: type[ResponseModelT],
*,
model_name: str,
temperature: float,
max_output_tokens: int,
) -> ResponseModelT:
prompt = self._json_prompt(messages, response_model)
response = self._client.models.generate_content(
model=model_name,
contents=prompt,
config=types.GenerateContentConfig(
temperature=temperature,
max_output_tokens=max_output_tokens,
candidate_count=1,
),
)
text = getattr(response, "text", None)
if not isinstance(text, str) or not text.strip():
raise RuntimeError("Gemini returned an empty prompt-only response.")
return response_model.model_validate_json(self._extract_json_object(text))
def _create_client(self, api_key: str | None) -> genai.Client:
load_dotenv(self._repo_root() / ".env", override=False)
key = api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
if not key:
raise RuntimeError("Missing GEMINI_API_KEY or GOOGLE_API_KEY.")
return genai.Client(api_key=key)
@staticmethod
def _repo_root() -> Path:
return Path(__file__).resolve().parents[2]
@staticmethod
def _split_messages(messages: list[ModelMessage]) -> tuple[str | None, list[str]]:
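        """Fold system messages into a single system instruction and join the rest into one content block."""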
system_parts: list[str] = []
content_parts: list[str] = []
for message in messages:
if message.role == "system":
system_parts.append(message.content)
continue
content_parts.append(f"{message.role.upper()}:\n{message.content}")
system_instruction = "\n\n".join(system_parts) if system_parts else None
contents = ["\n\n".join(content_parts)] if content_parts else [""]
return system_instruction, contents
@staticmethod
def _json_prompt(
messages: list[ModelMessage],
response_model: type[ResponseModelT],
) -> str:
message_blocks = [f"{message.role.upper()}:\n{message.content}" for message in messages]
schema = _schema_prompt_snippet(response_model)
conversation = "\n\n".join(message_blocks)
return (
"Return exactly one valid JSON object and nothing else.\n"
"Do not use markdown fences.\n"
"Use compact JSON with no commentary.\n"
f"JSON Schema:\n{schema}\n\n"
f"Conversation:\n{conversation}\n"
)
@staticmethod
def _extract_json_object(text: str) -> str:
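        """Strip markdown fences and a leading ``json`` tag, then return the outermost ``{...}`` slice."""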
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = cleaned.strip("`")
if cleaned.startswith("json"):
cleaned = cleaned[4:].lstrip()
start = cleaned.find("{")
end = cleaned.rfind("}")
if start == -1 or end == -1 or end < start:
raise RuntimeError("Gemini response did not contain a JSON object.")
return cleaned[start : end + 1]
@staticmethod
def _normalize_error(exc: Exception) -> str:
return " ".join(str(exc).split()) or exc.__class__.__name__
class HuggingFaceStructuredClient:
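    """Local Hugging Face client that prompts a causal LM for a single JSON object.

    The tokenizer and model are loaded lazily and cached per model name; an optional
    PEFT adapter can be applied, and 4-bit quantization is used when requested and CUDA
    is available.
    """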
def __init__(
self,
*,
adapter_path: str | None = None,
cache_dir: str | None = None,
load_in_4bit: bool = True,
trust_remote_code: bool = False,
device_map: str | None = "auto",
) -> None:
self.adapter_path = adapter_path
self.cache_dir = cache_dir
self.load_in_4bit = load_in_4bit
self.trust_remote_code = trust_remote_code
self.device_map = device_map
self._loaded_model_name: str | None = None
self._model: Any | None = None
self._tokenizer: Any | None = None
def generate_structured(
self,
messages: list[ModelMessage],
response_model: type[ResponseModelT],
*,
model_name: str,
temperature: float,
max_output_tokens: int,
) -> ResponseModelT:
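        """Render a JSON-only prompt, generate with the local model, and parse the completion."""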
tokenizer, model = self._ensure_model(model_name)
prompt = self._hf_prompt(messages, response_model)
rendered = self._render_prompt(tokenizer, prompt)
tokenized = tokenizer(rendered, return_tensors="pt")
tokenized = {key: value.to(model.device) for key, value in tokenized.items()}
generate_kwargs: dict[str, Any] = {
"max_new_tokens": max_output_tokens,
"do_sample": temperature > 0.0,
"temperature": max(temperature, 1e-5) if temperature > 0.0 else None,
"pad_token_id": getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", None),
"eos_token_id": getattr(tokenizer, "eos_token_id", None),
}
generate_kwargs = {key: value for key, value in generate_kwargs.items() if value is not None}
import torch
with torch.inference_mode():
output_ids = model.generate(**tokenized, **generate_kwargs)
prompt_length = tokenized["input_ids"].shape[1]
completion_ids = output_ids[0][prompt_length:]
text = tokenizer.decode(completion_ids, skip_special_tokens=True)
if not text.strip():
raise RuntimeError("Hugging Face model returned an empty response.")
return response_model.model_validate_json(self._extract_json_object(text))
def _ensure_model(self, model_name: str) -> tuple[Any, Any]:
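        """Load (or reuse) the tokenizer and model for ``model_name``, applying the optional PEFT adapter."""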
if self._model is not None and self._tokenizer is not None and self._loaded_model_name == model_name:
return self._tokenizer, self._model
load_dotenv(self._repo_root() / ".env", override=False)
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=self.cache_dir,
trust_remote_code=self.trust_remote_code,
token=_hf_token(),
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer = self._canonicalize_chat_template(tokenizer)
model_kwargs: dict[str, Any] = {
"cache_dir": self.cache_dir,
"trust_remote_code": self.trust_remote_code,
"token": _hf_token(),
}
model_kwargs.update(_hf_model_init_kwargs(load_in_4bit=self.load_in_4bit, device_map=self.device_map))
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
if self.adapter_path:
from peft import PeftModel
model = PeftModel.from_pretrained(model, self.adapter_path, is_trainable=False)
model.eval()
self._loaded_model_name = model_name
self._model = model
self._tokenizer = tokenizer
return tokenizer, model
@staticmethod
def _repo_root() -> Path:
return Path(__file__).resolve().parents[2]
@staticmethod
def _render_prompt(tokenizer: Any, prompt: str) -> str:
if hasattr(tokenizer, "apply_chat_template"):
chat_template_kwargs = HuggingFaceStructuredClient._chat_template_kwargs(tokenizer)
return tokenizer.apply_chat_template(
[
{"role": "system", "content": "Return exactly one valid JSON object and nothing else."},
{"role": "user", "content": prompt},
],
tokenize=False,
add_generation_prompt=True,
**chat_template_kwargs,
)
return prompt
@staticmethod
def _canonicalize_chat_template(tokenizer: Any) -> Any:
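        """Swap in trl's canonical Qwen3 chat template when the tokenizer already uses ChatML markers."""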
chat_template = getattr(tokenizer, "chat_template", "") or ""
if qwen3_chat_template is None:
return tokenizer
if "<|im_start|>" not in chat_template or "<|im_end|>" not in chat_template:
return tokenizer
tokenizer.chat_template = qwen3_chat_template
return tokenizer
@staticmethod
def _chat_template_kwargs(tokenizer: Any) -> dict[str, Any]:
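        """Return ``{"enable_thinking": False}`` if the chat template accepts that flag, else an empty dict."""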
if not hasattr(tokenizer, "apply_chat_template"):
return {}
try:
tokenizer.apply_chat_template(
[{"role": "user", "content": "ping"}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except Exception:
return {}
return {"enable_thinking": False}
@staticmethod
def _hf_prompt(
messages: list[ModelMessage],
response_model: type[ResponseModelT],
) -> str:
schema = _schema_prompt_snippet(response_model)
conversation = "\n\n".join(f"{message.role.upper()}:\n{message.content}" for message in messages)
return (
"Respond with exactly one compact JSON object and no other text.\n"
"Do not use markdown fences.\n"
f"JSON Schema:\n{schema}\n\n"
f"Conversation:\n{conversation}\n"
)
@staticmethod
def _extract_json_object(text: str) -> str:
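        """Strip markdown fences and a leading ``json`` tag, then return the outermost ``{...}`` slice."""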
cleaned = text.strip()
if cleaned.startswith("```"):
cleaned = cleaned.strip("`")
if cleaned.startswith("json"):
cleaned = cleaned[4:].lstrip()
start = cleaned.find("{")
end = cleaned.rfind("}")
if start == -1 or end == -1 or end < start:
raise RuntimeError("Hugging Face response did not contain a JSON object.")
return cleaned[start : end + 1]
def _schema_prompt_snippet(response_model: type[ResponseModelT]) -> str:
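    """Serialize the model's JSON schema compactly, summarizing it when it exceeds 4000 characters."""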
schema = response_model.model_json_schema()
serialized = json.dumps(schema, separators=(",", ":"))
if len(serialized) <= 4000:
return serialized
summarized = {
"title": schema.get("title", response_model.__name__),
"type": schema.get("type", "object"),
"required": schema.get("required", []),
"properties": {
key: {
field_name: value
for field_name, value in property_schema.items()
if field_name in {"type", "title", "enum", "items", "required", "$ref", "description"}
}
for key, property_schema in schema.get("properties", {}).items()
},
"defs": sorted(schema.get("$defs", {}).keys()),
}
return json.dumps(summarized, separators=(",", ":"))
def _hf_model_init_kwargs(*, load_in_4bit: bool, device_map: str | None) -> dict[str, Any]:
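    """Build ``from_pretrained`` kwargs: bfloat16 and a device map on CUDA, plus optional 4-bit NF4 quantization."""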
import torch
kwargs: dict[str, Any] = {
"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
}
if device_map is not None and torch.cuda.is_available():
kwargs["device_map"] = device_map
if load_in_4bit and torch.cuda.is_available():
from transformers import BitsAndBytesConfig
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
return kwargs
def _hf_token() -> str | None:
return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")