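"""Structured-output model clients.

Two providers implement the ``StructuredModelClient`` protocol below: a Gemini
API client and a local Hugging Face client. Callers pass a list of chat
messages plus a Pydantic response model and receive back a validated instance
of that model.
"""
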
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any, Protocol, TypeVar

from dotenv import load_dotenv
from google import genai
from google.genai import types
from pydantic import BaseModel

from .model_schema import ModelMessage

try:
    from trl.chat_template_utils import qwen3_chat_template
except Exception:  # pragma: no cover - optional runtime dependency
    qwen3_chat_template = None  # type: ignore[assignment]

ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel)

DEFAULT_GEMINI_DM_MODEL = "gemini-2.5-flash"
DEFAULT_GEMINI_HERO_MODEL = "gemini-2.5-flash"
DEFAULT_HF_DM_MODEL = "Qwen/Qwen3-32B"
DEFAULT_HF_HERO_MODEL = "Qwen/Qwen3-32B"

PROVIDER_GEMINI = "gemini"
PROVIDER_HF_LOCAL = "hf_local"
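

# Shared interface: any client that can turn a message list into a validated
# instance of the given Pydantic response model satisfies this protocol.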
class StructuredModelClient(Protocol):
    def generate_structured(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        ...
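

# Gemini client. generate_structured tries three strategies in order of
# decreasing strictness (native response_schema, JSON mode, plain prompt)
# and raises only after all three fail, recording each failure in the message.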
class GeminiStructuredClient:
    def __init__(self, api_key: str | None = None) -> None:
        self._client = self._create_client(api_key)

    def generate_structured(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        failures: list[str] = []
        strategies = (
            self._generate_with_response_schema,
            self._generate_with_json_mode,
            self._generate_with_prompt_only,
        )
        for strategy in strategies:
            try:
                return strategy(
                    messages,
                    response_model,
                    model_name=model_name,
                    temperature=temperature,
                    max_output_tokens=max_output_tokens,
                )
            except Exception as exc:
                failures.append(f"{strategy.__name__}: {self._normalize_error(exc)}")
        raise RuntimeError("Gemini structured generation failed. " + " | ".join(failures))
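
    # Strategy 1: let the SDK enforce the schema. The SDK may populate
    # `response.parsed` directly; otherwise validate the raw text.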
    def _generate_with_response_schema(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        system_instruction, contents = self._split_messages(messages)
        response = self._client.models.generate_content(
            model=model_name,
            contents=contents,
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                response_mime_type="application/json",
                response_schema=response_model,
                candidate_count=1,
            ),
        )
        parsed = getattr(response, "parsed", None)
        if parsed is not None:
            return response_model.model_validate(parsed)
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return response_model.model_validate_json(text)
        raise RuntimeError("Gemini returned an empty structured response.")
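
    # Strategy 2: JSON mode without a native schema; the schema is spelled out
    # in the prompt instead, and the reply is validated client-side.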
    def _generate_with_json_mode(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        prompt = self._json_prompt(messages, response_model)
        response = self._client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                response_mime_type="application/json",
                candidate_count=1,
            ),
        )
        text = getattr(response, "text", None)
        if not isinstance(text, str) or not text.strip():
            raise RuntimeError("Gemini returned an empty JSON-mode response.")
        return response_model.model_validate_json(text)
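
    # Strategy 3: no JSON mode at all; extract the first {...} span from the
    # free-form reply before validating.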
    def _generate_with_prompt_only(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        prompt = self._json_prompt(messages, response_model)
        response = self._client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                candidate_count=1,
            ),
        )
        text = getattr(response, "text", None)
        if not isinstance(text, str) or not text.strip():
            raise RuntimeError("Gemini returned an empty prompt-only response.")
        return response_model.model_validate_json(self._extract_json_object(text))

    def _create_client(self, api_key: str | None) -> genai.Client:
        load_dotenv(self._repo_root() / ".env", override=False)
        key = api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not key:
            raise RuntimeError("Missing GEMINI_API_KEY or GOOGLE_API_KEY.")
        return genai.Client(api_key=key)

    @staticmethod
    def _repo_root() -> Path:
        return Path(__file__).resolve().parents[2]

    @staticmethod
    def _split_messages(messages: list[ModelMessage]) -> tuple[str | None, list[str]]:
        system_parts: list[str] = []
        content_parts: list[str] = []
        for message in messages:
            if message.role == "system":
                system_parts.append(message.content)
                continue
            content_parts.append(f"{message.role.upper()}:\n{message.content}")
        system_instruction = "\n\n".join(system_parts) if system_parts else None
        contents = ["\n\n".join(content_parts)] if content_parts else [""]
        return system_instruction, contents
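
    # Prompt used by the JSON-mode and prompt-only strategies: formatting
    # rules first, then the schema, then the flattened conversation.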
    @staticmethod
    def _json_prompt(
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
    ) -> str:
        message_blocks = [f"{message.role.upper()}:\n{message.content}" for message in messages]
        schema = _schema_prompt_snippet(response_model)
        conversation = "\n\n".join(message_blocks)
        return (
            "Return exactly one valid JSON object and nothing else.\n"
            "Do not use markdown fences.\n"
            "Use compact JSON with no commentary.\n"
            f"JSON Schema:\n{schema}\n\n"
            f"Conversation:\n{conversation}\n"
        )

    @staticmethod
    def _extract_json_object(text: str) -> str:
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`")
            if cleaned.startswith("json"):
                cleaned = cleaned[4:].lstrip()
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end == -1 or end < start:
            raise RuntimeError("Gemini response did not contain a JSON object.")
        return cleaned[start : end + 1]

    @staticmethod
    def _normalize_error(exc: Exception) -> str:
        return " ".join(str(exc).split()) or exc.__class__.__name__
class HuggingFaceStructuredClient:
    def __init__(
        self,
        *,
        adapter_path: str | None = None,
        cache_dir: str | None = None,
        load_in_4bit: bool = True,
        trust_remote_code: bool = False,
        device_map: str | None = "auto",
    ) -> None:
        self.adapter_path = adapter_path
        self.cache_dir = cache_dir
        self.load_in_4bit = load_in_4bit
        self.trust_remote_code = trust_remote_code
        self.device_map = device_map
        self._loaded_model_name: str | None = None
        self._model: Any | None = None
        self._tokenizer: Any | None = None

    def generate_structured(
        self,
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
        *,
        model_name: str,
        temperature: float,
        max_output_tokens: int,
    ) -> ResponseModelT:
        tokenizer, model = self._ensure_model(model_name)
        prompt = self._hf_prompt(messages, response_model)
        rendered = self._render_prompt(tokenizer, prompt)
        tokenized = tokenizer(rendered, return_tensors="pt")
        tokenized = {key: value.to(model.device) for key, value in tokenized.items()}
        generate_kwargs: dict[str, Any] = {
            "max_new_tokens": max_output_tokens,
            "do_sample": temperature > 0.0,
            # Greedy decoding at temperature 0; clamp tiny positive values.
            "temperature": max(temperature, 1e-5) if temperature > 0.0 else None,
            "pad_token_id": getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", None),
            "eos_token_id": getattr(tokenizer, "eos_token_id", None),
        }
        generate_kwargs = {key: value for key, value in generate_kwargs.items() if value is not None}

        import torch  # deferred so Gemini-only runs never import torch

        with torch.inference_mode():
            output_ids = model.generate(**tokenized, **generate_kwargs)
        # Decode only the completion, not the echoed prompt.
        prompt_length = tokenized["input_ids"].shape[1]
        completion_ids = output_ids[0][prompt_length:]
        text = tokenizer.decode(completion_ids, skip_special_tokens=True)
        if not text.strip():
            raise RuntimeError("Hugging Face model returned an empty response.")
        return response_model.model_validate_json(self._extract_json_object(text))
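
    # Load (or reuse) the tokenizer and model; optionally attach a PEFT
    # adapter in inference-only mode.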
    def _ensure_model(self, model_name: str) -> tuple[Any, Any]:
        if self._model is not None and self._tokenizer is not None and self._loaded_model_name == model_name:
            return self._tokenizer, self._model
        load_dotenv(self._repo_root() / ".env", override=False)

        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=self.cache_dir,
            trust_remote_code=self.trust_remote_code,
            token=_hf_token(),
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer = self._canonicalize_chat_template(tokenizer)
        model_kwargs: dict[str, Any] = {
            "cache_dir": self.cache_dir,
            "trust_remote_code": self.trust_remote_code,
            "token": _hf_token(),
        }
        model_kwargs.update(_hf_model_init_kwargs(load_in_4bit=self.load_in_4bit, device_map=self.device_map))
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
        if self.adapter_path:
            from peft import PeftModel

            model = PeftModel.from_pretrained(model, self.adapter_path, is_trainable=False)
        model.eval()
        self._loaded_model_name = model_name
        self._model = model
        self._tokenizer = tokenizer
        return tokenizer, model

    @staticmethod
    def _repo_root() -> Path:
        return Path(__file__).resolve().parents[2]

    @staticmethod
    def _render_prompt(tokenizer: Any, prompt: str) -> str:
        if hasattr(tokenizer, "apply_chat_template"):
            chat_template_kwargs = HuggingFaceStructuredClient._chat_template_kwargs(tokenizer)
            return tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "Return exactly one valid JSON object and nothing else."},
                    {"role": "user", "content": prompt},
                ],
                tokenize=False,
                add_generation_prompt=True,
                **chat_template_kwargs,
            )
        return prompt

    @staticmethod
    def _canonicalize_chat_template(tokenizer: Any) -> Any:
        # Swap in TRL's canonical Qwen3 template when the tokenizer ships a
        # ChatML-style (<|im_start|>/<|im_end|>) template and TRL is available.
        chat_template = getattr(tokenizer, "chat_template", "") or ""
        if qwen3_chat_template is None:
            return tokenizer
        if "<|im_start|>" not in chat_template or "<|im_end|>" not in chat_template:
            return tokenizer
        tokenizer.chat_template = qwen3_chat_template
        return tokenizer

    @staticmethod
    def _chat_template_kwargs(tokenizer: Any) -> dict[str, Any]:
        # Probe whether this template accepts enable_thinking (Qwen3) by
        # rendering a dummy message; pass the kwarg only if the probe succeeds.
        if not hasattr(tokenizer, "apply_chat_template"):
            return {}
        try:
            tokenizer.apply_chat_template(
                [{"role": "user", "content": "ping"}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        except Exception:
            return {}
        return {"enable_thinking": False}
    @staticmethod
    def _hf_prompt(
        messages: list[ModelMessage],
        response_model: type[ResponseModelT],
    ) -> str:
        schema = _schema_prompt_snippet(response_model)
        conversation = "\n\n".join(f"{message.role.upper()}:\n{message.content}" for message in messages)
        return (
            "Respond with exactly one compact JSON object and no other text.\n"
            "Do not use markdown fences.\n"
            f"JSON Schema:\n{schema}\n\n"
            f"Conversation:\n{conversation}\n"
        )

    @staticmethod
    def _extract_json_object(text: str) -> str:
        cleaned = text.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`")
            if cleaned.startswith("json"):
                cleaned = cleaned[4:].lstrip()
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start == -1 or end == -1 or end < start:
            raise RuntimeError("Hugging Face response did not contain a JSON object.")
        return cleaned[start : end + 1]
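

# Module-level helper shared by both clients: embed the response model's JSON
# schema in the prompt. Schemas over 4000 characters serialized are summarized
# down to their top-level shape to keep the prompt compact.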
def _schema_prompt_snippet(response_model: type[ResponseModelT]) -> str:
    schema = response_model.model_json_schema()
    serialized = json.dumps(schema, separators=(",", ":"))
    if len(serialized) <= 4000:
        return serialized
    summarized = {
        "title": schema.get("title", response_model.__name__),
        "type": schema.get("type", "object"),
        "required": schema.get("required", []),
        "properties": {
            key: {
                field_name: value
                for field_name, value in property_schema.items()
                if field_name in {"type", "title", "enum", "items", "required", "$ref", "description"}
            }
            for key, property_schema in schema.get("properties", {}).items()
        },
        "defs": sorted(schema.get("$defs", {}).keys()),
    }
    return json.dumps(summarized, separators=(",", ":"))
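

# Model-loading kwargs: bfloat16 plus optional NF4 4-bit quantization on CUDA,
# float32 on CPU (device_map and bitsandbytes are skipped without a GPU).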
def _hf_model_init_kwargs(*, load_in_4bit: bool, device_map: str | None) -> dict[str, Any]:
    import torch

    kwargs: dict[str, Any] = {
        "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    }
    if device_map is not None and torch.cuda.is_available():
        kwargs["device_map"] = device_map
    if load_in_4bit and torch.cuda.is_available():
        from transformers import BitsAndBytesConfig

        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    return kwargs
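

# Hugging Face Hub auth token, if configured in the environment.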
def _hf_token() -> str | None:
    return os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
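

# Usage sketch (hypothetical response model and message; assumes ModelMessage
# takes `role` and `content` fields, as the helpers above rely on):
#
#     class DMTurn(BaseModel):
#         narration: str
#
#     client = GeminiStructuredClient()
#     turn = client.generate_structured(
#         [ModelMessage(role="user", content="Describe the tavern.")],
#         DMTurn,
#         model_name=DEFAULT_GEMINI_DM_MODEL,
#         temperature=0.7,
#         max_output_tokens=512,
#     )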