"""Structured-generation clients: the Gemini API and local Hugging Face models.

Both clients take a list of ModelMessage objects plus a Pydantic response model
and return a validated instance of that model.
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, Protocol, TypeVar

from dotenv import load_dotenv
from google import genai
from google.genai import types
from pydantic import BaseModel

from .model_schema import ModelMessage

try:
    from trl.chat_template_utils import qwen3_chat_template
except Exception:  # pragma: no cover - optional runtime dependency
    qwen3_chat_template = None  # type: ignore[assignment]

ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel)

DEFAULT_GEMINI_DM_MODEL = "gemini-2.5-flash"
DEFAULT_GEMINI_HERO_MODEL = "gemini-2.5-flash"
DEFAULT_HF_DM_MODEL = "Qwen/Qwen3-32B"
DEFAULT_HF_HERO_MODEL = "Qwen/Qwen3-32B"
PROVIDER_GEMINI = "gemini"
PROVIDER_HF_LOCAL = "hf_local"


class StructuredModelClient(Protocol):
    """Common interface shared by the Gemini and Hugging Face clients."""

    def generate_structured(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT: ...


class GeminiStructuredClient:
    def __init__(self, api_key: str | None = None) -> None:
        self._client = self._create_client(api_key)

    def generate_structured(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        # Try the strategies from strictest to loosest, collecting per-strategy
        # failures so the final error message explains every attempt.
        failures: list[str] = []
        strategies = (
            self._generate_with_response_schema,
            self._generate_with_json_mode,
            self._generate_with_prompt_only,
        )
        for strategy in strategies:
            try:
                return strategy(
                    messages,
                    response_model,
                    model_name=model_name,
                    temperature=temperature,
                    max_output_tokens=max_output_tokens,
                )
            except Exception as exc:
                failures.append(f"{strategy.__name__}: {self._normalize_error(exc)}")
        raise RuntimeError("Gemini structured generation failed. " + " | ".join(failures))

    def _generate_with_response_schema(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        system_instruction, contents = self._split_messages(messages)
        response = self._client.models.generate_content(
            model=model_name,
            contents=contents,
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                response_mime_type="application/json",
                response_schema=response_model,
                candidate_count=1,
            ),
        )
        parsed = getattr(response, "parsed", None)
        if parsed is not None:
            return response_model.model_validate(parsed)
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return response_model.model_validate_json(text)
        raise RuntimeError("Gemini returned an empty structured response.")

    def _generate_with_json_mode(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        prompt = self._json_prompt(messages, response_model)
        response = self._client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                response_mime_type="application/json",
                candidate_count=1,
            ),
        )
        text = getattr(response, "text", None)
        if not isinstance(text, str) or not text.strip():
            raise RuntimeError("Gemini returned an empty JSON-mode response.")
        return response_model.model_validate_json(text)
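    # Last-resort fallback: no JSON mime type is requested at all, so the
    # model replies in free text and we scrape the first {...} span out of it
    # via _extract_json_object before validating.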
    def _generate_with_prompt_only(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        prompt = self._json_prompt(messages, response_model)
        response = self._client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                candidate_count=1,
            ),
        )
        text = getattr(response, "text", None)
        if not isinstance(text, str) or not text.strip():
            raise RuntimeError("Gemini returned an empty prompt-only response.")
        return response_model.model_validate_json(self._extract_json_object(text))

    def _create_client(self, api_key: str | None) -> genai.Client:
        load_dotenv(self._repo_root() / ".env", override=False)
        key = api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not key:
            raise RuntimeError("Missing GEMINI_API_KEY or GOOGLE_API_KEY.")
        return genai.Client(api_key=key)

    @staticmethod
    def _repo_root() -> Path:
        return Path(__file__).resolve().parents[2]

    @staticmethod
    def _split_messages(messages: list[ModelMessage]) -> tuple[str | None, list[str]]:
        system_parts: list[str] = []
        content_parts: list[str] = []
        for message in messages:
            if message.role == "system":
                system_parts.append(message.content)
                continue
            content_parts.append(f"{message.role.upper()}:\n{message.content}")
        system_instruction = "\n\n".join(system_parts) if system_parts else None
        contents = ["\n\n".join(content_parts)] if content_parts else [""]
        return system_instruction, contents

    @staticmethod
    def _json_prompt(
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
    ) -> str:
        message_blocks = [f"{message.role.upper()}:\n{message.content}" for message in messages]
        schema = _schema_prompt_snippet(response_model)
        conversation = "\n\n".join(message_blocks)
        return (
            "Return exactly one valid JSON object and nothing else.\n"
            "Do not use markdown fences.\n"
            "Use compact JSON with no commentary.\n"
            f"JSON Schema:\n{schema}\n\n"
            f"Conversation:\n{conversation}\n"
        )

    @staticmethod
    def _extract_json_object(text: str) -> str:
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`")
            if cleaned.startswith("json"):
                cleaned = cleaned[4:].lstrip()
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end == -1 or end < start:
            raise RuntimeError("Gemini response did not contain a JSON object.")
        return cleaned[start : end + 1]

    @staticmethod
    def _normalize_error(exc: Exception) -> str:
        return " ".join(str(exc).split()) or exc.__class__.__name__


class HuggingFaceStructuredClient:
    def __init__(
        self,
        *,
        adapter_path: str | None = None,
        cache_dir: str | None = None,
        load_in_4bit: bool = True,
        trust_remote_code: bool = False,
        device_map: str | None = "auto",
    ) -> None:
        self.adapter_path = adapter_path
        self.cache_dir = cache_dir
        self.load_in_4bit = load_in_4bit
        self.trust_remote_code = trust_remote_code
        self.device_map = device_map
        self._loaded_model_name: str | None = None
        self._model: Any | None = None
        self._tokenizer: Any | None = None

    def generate_structured(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        tokenizer, model = self._ensure_model(model_name)
        prompt = self._hf_prompt(messages, response_model)
        rendered = self._render_prompt(tokenizer, prompt)
        tokenized = tokenizer(rendered, return_tensors="pt")
        tokenized = {key: value.to(model.device) for key, value in tokenized.items()}
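        # Build generation kwargs defensively: sample only when temperature is
        # positive (clamping it away from zero), and drop None-valued entries
        # so generate() only sees explicitly-set arguments.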
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            # Explicit None check: `or` would discard a legitimate pad id of 0.
            pad_token_id = getattr(tokenizer, "eos_token_id", None)
        generate_kwargs: dict[str, Any] = {
            "max_new_tokens": max_output_tokens,
            "do_sample": temperature > 0.0,
            "temperature": max(temperature, 1e-5) if temperature > 0.0 else None,
            "pad_token_id": pad_token_id,
            "eos_token_id": getattr(tokenizer, "eos_token_id", None),
        }
        generate_kwargs = {key: value for key, value in generate_kwargs.items() if value is not None}

        import torch  # imported lazily so the Gemini path never requires torch

        with torch.inference_mode():
            output_ids = model.generate(**tokenized, **generate_kwargs)
        # Decode only the completion, not the echoed prompt tokens.
        prompt_length = tokenized["input_ids"].shape[1]
        completion_ids = output_ids[0][prompt_length:]
        text = tokenizer.decode(completion_ids, skip_special_tokens=True)
        if not text.strip():
            raise RuntimeError("Hugging Face model returned an empty response.")
        return response_model.model_validate_json(self._extract_json_object(text))

    def _ensure_model(self, model_name: str) -> tuple[Any, Any]:
        # Cache the loaded model/tokenizer and reuse them while the requested
        # model name stays the same.
        if self._model is not None and self._tokenizer is not None and self._loaded_model_name == model_name:
            return self._tokenizer, self._model
        load_dotenv(self._repo_root() / ".env", override=False)
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=self.cache_dir,
            trust_remote_code=self.trust_remote_code,
            token=_hf_token(),
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer = self._canonicalize_chat_template(tokenizer)
        model_kwargs: dict[str, Any] = {
            "cache_dir": self.cache_dir,
            "trust_remote_code": self.trust_remote_code,
            "token": _hf_token(),
        }
        model_kwargs.update(_hf_model_init_kwargs(load_in_4bit=self.load_in_4bit, device_map=self.device_map))
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
        if self.adapter_path:
            from peft import PeftModel

            model = PeftModel.from_pretrained(model, self.adapter_path, is_trainable=False)
        model.eval()
        self._loaded_model_name = model_name
        self._model = model
        self._tokenizer = tokenizer
        return tokenizer, model

    @staticmethod
    def _repo_root() -> Path:
        return Path(__file__).resolve().parents[2]

    @staticmethod
    def _render_prompt(tokenizer: Any, prompt: str) -> str:
        if hasattr(tokenizer, "apply_chat_template"):
            chat_template_kwargs = HuggingFaceStructuredClient._chat_template_kwargs(tokenizer)
            return tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "Return exactly one valid JSON object and nothing else."},
                    {"role": "user", "content": prompt},
                ],
                tokenize=False,
                add_generation_prompt=True,
                **chat_template_kwargs,
            )
        return prompt

    @staticmethod
    def _canonicalize_chat_template(tokenizer: Any) -> Any:
        # Swap in TRL's canonical Qwen3 template when the tokenizer ships a
        # ChatML-style (<|im_start|>/<|im_end|>) template and TRL is installed.
        chat_template = getattr(tokenizer, "chat_template", "") or ""
        if qwen3_chat_template is None:
            return tokenizer
        if "<|im_start|>" not in chat_template or "<|im_end|>" not in chat_template:
            return tokenizer
        tokenizer.chat_template = qwen3_chat_template
        return tokenizer

    @staticmethod
    def _chat_template_kwargs(tokenizer: Any) -> dict[str, Any]:
        # Probe whether the template accepts enable_thinking (Qwen3-style);
        # fall back to no extra kwargs if the probe render fails.
        if not hasattr(tokenizer, "apply_chat_template"):
            return {}
        try:
            tokenizer.apply_chat_template(
                [{"role": "user", "content": "ping"}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        except Exception:
            return {}
        return {"enable_thinking": False}

    @staticmethod
    def _hf_prompt(
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
    ) -> str:
        schema = _schema_prompt_snippet(response_model)
        conversation = "\n\n".join(f"{message.role.upper()}:\n{message.content}" for message in messages)
        return (
            "Respond with exactly one compact JSON object and no other text.\n"
            "Do not use markdown fences.\n"
            f"JSON Schema:\n{schema}\n\n"
            f"Conversation:\n{conversation}\n"
        )
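    # Same fence-stripping heuristic as the Gemini client: models sometimes
    # wrap output in ```json fences despite instructions, so keep only the
    # outermost {...} span.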
    @staticmethod
    def _extract_json_object(text: str) -> str:
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`")
            if cleaned.startswith("json"):
                cleaned = cleaned[4:].lstrip()
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end == -1 or end < start:
            raise RuntimeError("Hugging Face response did not contain a JSON object.")
        return cleaned[start : end + 1]


def _schema_prompt_snippet(response_model: type[ResponseModelT]) -> str:
    schema = response_model.model_json_schema()
    serialized = json.dumps(schema, separators=(",", ":"))
    if len(serialized) <= 4000:
        return serialized
    summarized = {
        "title": schema.get("title", response_model.__name__),
        "type": schema.get("type", "object"),
        "required": schema.get("required", []),
        "properties": {
            key: {
                field_name: value
                for field_name, value in property_schema.items()
                if field_name in {"type", "title", "enum", "items", "required", "$ref", "description"}
            }
            for key, property_schema in schema.get("properties", {}).items()
        },
        "defs": sorted(schema.get("$defs", {}).keys()),
    }
    return json.dumps(summarized, separators=(",", ":"))


def _hf_model_init_kwargs(*, load_in_4bit: bool, device_map: str | None) -> dict[str, Any]:
    import torch

    kwargs: dict[str, Any] = {
        "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    }
    if device_map is not None and torch.cuda.is_available():
        kwargs["device_map"] = device_map
    if load_in_4bit and torch.cuda.is_available():
        from transformers import BitsAndBytesConfig

        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    return kwargs


def _hf_token() -> str | None:
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
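# Usage sketch (illustrative only; the response model below is hypothetical,
# and ModelMessage is assumed to take `role` and `content` fields, matching
# how the clients read it above):
#
#     from pydantic import BaseModel
#
#     class SceneBeat(BaseModel):
#         narration: str
#         options: list[str]
#
#     client = GeminiStructuredClient()
#     beat = client.generate_structured(
#         [
#             ModelMessage(role="system", content="You are the dungeon master."),
#             ModelMessage(role="user", content="Describe the cave entrance."),
#         ],
#         SceneBeat,
#         model_name=DEFAULT_GEMINI_DM_MODEL,
#         temperature=0.7,
#         max_output_tokens=512,
#     )
#     print(beat.narration)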