import hashlib
import json
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from azure.identity import (
    AzureCliCredential,
    ChainedTokenCredential,
    ManagedIdentityCredential,
    get_bearer_token_provider,
)
from openai import AzureOpenAI
from tqdm import tqdm

from src.config import (
    DEFAULT_MODEL,
    MAX_TOKENS,
    MAX_TOKENS_THINKING,
    TEMPERATURE,
    TRAPI_SCOPE,
    build_azure_endpoint,
)
from src.model_registry import get_model_config
from src.generation_utils import fill_template_file
from src.generation_rates import estimate_cost

Message = Dict[str, str]
PromptMessages = List[Message]


class Generator:
    """Chat-completion wrapper with an on-disk prompt cache, retries and usage tracking.

    Each instance is bound to one model alias. Successful responses are cached
    in a JSON file keyed by a hash of the request parameters, and every API
    call is appended to a per-model JSONL request log.
    """

    # Prefixes removed before re-trying a JSON parse. Applied in order with
    # ``re.match`` (i.e. anchored at the start of the remaining text).
    # NOTE(review): the tag-based patterns assume Qwen-style
    # ``<think>...</think>`` reasoning blocks -- confirm against real outputs.
    _THINKING_PATTERNS = (
        r"(?s)<think>.*?</think>\s*",
        r"(?s)^Thinking Process:.*?(?=\{)",
        # Closing tag without an opening tag: drop everything up to it.
        r"(?s)^.*?</think>\s*",
    )

    def __init__(
        self,
        model_alias: Optional[str] = None,
        use_cache: bool = True,
        cache_path: Optional[Union[str, Path]] = None,
        max_retries: int = 5,
        retry_wait_seconds: int = 90,
    ) -> None:
        """Create a generator for one model.

        Args:
            model_alias: Registry alias of the model; falls back to
                ``DEFAULT_MODEL`` when omitted or empty.
            use_cache: Whether to read and write the on-disk prompt cache.
            cache_path: Cache file location; defaults to
                ``generation_logs/prompt_cache.json``.
            max_retries: Maximum number of attempts per request.
            retry_wait_seconds: Pause between failed attempts, in seconds.
        """
        self.model_alias = model_alias or DEFAULT_MODEL
        self.model_config = get_model_config(self.model_alias)
        self.use_cache = use_cache
        self.max_retries = max_retries
        self.retry_wait_seconds = retry_wait_seconds

        # Usage tracking
        self.usage_stats = {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_cached_tokens": 0,
            "total_requests": 0,
            "cache_hits": 0,
        }

        # Request log file (append-only JSONL, one per model)
        self.request_log_path = (
            Path("generation_logs") / f"requests_{self.model_alias}.jsonl"
        )
        self.request_log_path.parent.mkdir(parents=True, exist_ok=True)

        self.cache_path = (
            Path(cache_path)
            if cache_path
            else Path("generation_logs/prompt_cache.json")
        )
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)

        self.client = self._build_client()
        self.prompt_cache: Dict[str, str] = self._load_prompt_cache()

    def _build_client(self) -> AzureOpenAI:
        """Build the Azure OpenAI client.

        Credentials are tried in order: Azure CLI, then managed identity. The
        model config's ``endpoint_override`` wins over the default endpoint.
        """
        token_provider = get_bearer_token_provider(
            ChainedTokenCredential(
                AzureCliCredential(),
                ManagedIdentityCredential(),
            ),
            TRAPI_SCOPE,
        )
        endpoint = (
            self.model_config.endpoint_override
            if self.model_config.endpoint_override
            else build_azure_endpoint()
        )
        return AzureOpenAI(
            azure_endpoint=endpoint,
            azure_ad_token_provider=token_provider,
            api_version=self.model_config.api_version,
        )

    def _load_prompt_cache(self) -> Dict[str, str]:
        """Load the prompt cache from disk.

        Returns an empty dict when caching is disabled, the file is missing,
        or the file cannot be read as a JSON object. Loading is best-effort:
        problems are reported on stdout instead of being raised.
        """
        if not self.use_cache or not self.cache_path.exists():
            return {}
        try:
            loaded = json.loads(self.cache_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers both invalid JSON and undecodable bytes.
            print(
                f"[Generator] Ignoring unreadable prompt cache "
                f"'{self.cache_path}': {exc}"
            )
            return {}
        if not isinstance(loaded, dict):
            print(
                f"[Generator] Ignoring prompt cache '{self.cache_path}': "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            return {}
        return loaded

    def _save_prompt_cache(self) -> None:
        """Persist the in-memory prompt cache; a no-op when caching is disabled.

        The file is written to a temporary sibling and then renamed over the
        real cache, so an interrupted write cannot leave a truncated cache
        behind. I/O failures are reported but not raised: the response is
        already in memory and the cache is only an optimisation.
        """
        if not self.use_cache:
            return
        serialized = json.dumps(self.prompt_cache, indent=2, ensure_ascii=False)
        tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
        try:
            tmp_path.write_text(serialized, encoding="utf-8")
            tmp_path.replace(self.cache_path)
        except OSError as exc:
            print(
                f"[Generator] Could not write prompt cache "
                f"'{self.cache_path}': {exc}"
            )

    @staticmethod
    def _make_cache_key(
        model_alias: str,
        messages: PromptMessages,
        temperature: float,
        max_completion_tokens: int,
        response_format: Optional[Dict[str, Any]],
    ) -> str:
        """Return the SHA-256 hex digest identifying one request.

        The digest is taken over a key-sorted JSON dump of all arguments, so
        equal requests map to the same key regardless of dict ordering.
        """
        payload = {
            "model_alias": model_alias,
            "messages": messages,
            "temperature": temperature,
            "max_completion_tokens": max_completion_tokens,
            "response_format": response_format,
        }
        raw = json.dumps(payload, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def _chat_once(
        self,
        messages: PromptMessages,
        temperature: float,
        max_completion_tokens: int,
        response_format: Optional[Dict[str, Any]],
    ) -> str:
        """Issue a single chat-completion call, without caching or retries.

        Updates ``usage_stats`` and appends a record to the request log.

        Returns:
            The first choice's message content, or ``""`` when it is empty.

        Raises:
            Whatever the underlying client raises; ``chat`` handles retries.
        """
        kwargs: Dict[str, Any] = {
            "model": self.model_config.deployment_name,
            "messages": messages,
            "max_completion_tokens": max_completion_tokens,
            "temperature": temperature,
        }

        # Only pass structured response_format to models that support it.
        if response_format is not None and self.model_config.is_openai_compatible:
            kwargs["response_format"] = response_format

        # Disable thinking for Qwen 3.5 models to avoid slow reasoning tokens
        if "qwen" in self.model_alias.lower() and "3.5" in self.model_alias:
            kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}

        response = self.client.chat.completions.create(**kwargs)
        content = response.choices[0].message.content or ""

        # Default to zero so a response without usage data is still counted
        # and logged instead of leaving these names unbound.
        prompt_tokens = 0
        completion_tokens = 0
        cached_tokens = 0
        usage = getattr(response, "usage", None)
        if usage:
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0
            prompt_details = getattr(usage, "prompt_tokens_details", None)
            if prompt_details:
                cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0

        self.usage_stats["total_prompt_tokens"] += prompt_tokens
        self.usage_stats["total_completion_tokens"] += completion_tokens
        self.usage_stats["total_cached_tokens"] += cached_tokens
        self.usage_stats["total_requests"] += 1

        try:
            self._log_request(
                messages=messages,
                temperature=temperature,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                cached_tokens=cached_tokens,
            )
        except OSError as exc:
            # A logging failure must not discard a response that was already
            # received (the caller would otherwise retry the model call).
            print(f"[Generator] Could not write request log: {exc}")

        return content

    def _log_request(
        self,
        messages: PromptMessages,
        temperature: float,
        max_completion_tokens: int,
        response_format: Optional[Dict[str, Any]],
        prompt_tokens: int,
        completion_tokens: int,
        cached_tokens: int,
    ) -> None:
        """Append a single request record to the JSONL log.

        Message contents are not logged; only request parameters, token
        counts and the estimated cost are recorded.
        """
        record = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "model_alias": self.model_alias,
            "deployment": self.model_config.deployment_name,
            "temperature": temperature,
            "max_completion_tokens": max_completion_tokens,
            "has_response_format": response_format is not None,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cached_tokens": cached_tokens,
            "cost_usd": estimate_cost(
                self.model_alias, prompt_tokens, completion_tokens, cached_tokens
            ),
        }
        with open(self.request_log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    def chat(
        self,
        messages: PromptMessages,
        temperature: Optional[float] = None,
        max_completion_tokens: Optional[int] = None,
        response_format: Optional[Dict[str, Any]] = None,
        dont_use_cached: bool = False,
    ) -> str:
        """Return the model's reply to ``messages``, using the cache and retries.

        Args:
            messages: Chat messages to send.
            temperature: Sampling temperature; defaults to ``TEMPERATURE``.
            max_completion_tokens: Completion budget. When omitted,
                ``MAX_TOKENS_THINKING`` is used for models that are not
                OpenAI-compatible and ``MAX_TOKENS`` otherwise.
            response_format: Optional structured-output specification.
            dont_use_cached: Skip the cache lookup. The fresh response is
                still written to the cache.

        Returns:
            The response text.

        Raises:
            RuntimeError: If every attempt failed; chained from the last error.
        """
        final_temperature = TEMPERATURE if temperature is None else temperature

        if max_completion_tokens is not None:
            final_max_tokens = max_completion_tokens
        elif not self.model_config.is_openai_compatible:
            final_max_tokens = MAX_TOKENS_THINKING
        else:
            final_max_tokens = MAX_TOKENS

        cache_key = self._make_cache_key(
            model_alias=self.model_alias,
            messages=messages,
            temperature=final_temperature,
            max_completion_tokens=final_max_tokens,
            response_format=response_format,
        )

        if self.use_cache and not dont_use_cached and cache_key in self.prompt_cache:
            self.usage_stats["cache_hits"] += 1
            return self.prompt_cache[cache_key]

        last_error: Optional[Exception] = None

        for attempt in range(1, self.max_retries + 1):
            try:
                content = self._chat_once(
                    messages=messages,
                    temperature=final_temperature,
                    max_completion_tokens=final_max_tokens,
                    response_format=response_format,
                )
            except Exception as exc:
                last_error = exc
                print(
                    f"[Generator] Model call failed on attempt "
                    f"{attempt}/{self.max_retries} for model '{self.model_alias}': {exc}"
                )
                if attempt == self.max_retries:
                    break
                print(
                    f"[Generator] Waiting {self.retry_wait_seconds} seconds before retry..."
                )
                time.sleep(self.retry_wait_seconds)
            else:
                # Cache persistence is deliberately outside the ``try``: a
                # disk problem must not trigger another model call.
                if self.use_cache:
                    self.prompt_cache[cache_key] = content
                    self._save_prompt_cache()
                return content

        raise RuntimeError(
            f"Model call failed after {self.max_retries} attempts "
            f"for model '{self.model_alias}'."
        ) from last_error

    def prompt(
        self,
        prompts: Sequence[PromptMessages],
        temperature: Optional[float] = None,
        max_completion_tokens: Optional[int] = None,
        response_format: Optional[Dict[str, Any]] = None,
        dont_use_cached: bool = False,
        skip_failures: bool = False,
    ) -> List[Optional[str]]:
        """Run ``chat`` over ``prompts`` sequentially, showing a progress bar.

        Args:
            prompts: One message list per request.
            temperature: Forwarded to ``chat``.
            max_completion_tokens: Forwarded to ``chat``.
            response_format: Forwarded to ``chat``.
            dont_use_cached: Forwarded to ``chat``.
            skip_failures: When true, a failed item yields ``None`` in the
                result; otherwise the first failure is re-raised.

        Returns:
            Responses in the same order as ``prompts``.
        """
        outputs: List[Optional[str]] = []

        for prompt_messages in tqdm(
            prompts,
            desc=f"[Generator:{self.model_alias}]",
            unit="req",
        ):
            try:
                output = self.chat(
                    messages=prompt_messages,
                    temperature=temperature,
                    max_completion_tokens=max_completion_tokens,
                    response_format=response_format,
                    dont_use_cached=dont_use_cached,
                )
            except Exception as exc:
                if not skip_failures:
                    raise
                print(f"[Generator] Skipping failed item: {exc}")
                outputs.append(None)
            else:
                outputs.append(output)

        return outputs

    def get_usage_summary(self) -> Dict[str, Any]:
        """Return a copy of the usage stats plus model alias and estimated cost."""
        stats = dict(self.usage_stats)
        stats["model_alias"] = self.model_alias
        stats["estimated_cost_usd"] = estimate_cost(
            self.model_alias,
            stats["total_prompt_tokens"],
            stats["total_completion_tokens"],
            stats["total_cached_tokens"],
        )
        return stats

    def print_usage_summary(self, stage: str = "") -> None:
        """Print a human-readable usage summary.

        Args:
            stage: Optional label printed in brackets before the summary.
        """
        s = self.get_usage_summary()
        label = f"[{stage}] " if stage else ""
        # A ``None`` cost is reported as missing pricing rather than formatted.
        cost_str = (
            f"${s['estimated_cost_usd']:.4f}"
            if s["estimated_cost_usd"] is not None
            else "N/A (no pricing)"
        )
        print(
            f"{label}Usage for {s['model_alias']}: "
            f"{s['total_prompt_tokens']} prompt tokens, "
            f"{s['total_completion_tokens']} completion tokens, "
            f"{s['total_cached_tokens']} cached tokens | "
            f"{s['total_requests']} API calls, "
            f"{s['cache_hits']} cache hits | "
            f"Est. cost: {cost_str}"
        )

    def build_prompts(
        self,
        template_path: Union[str, Path],
        data: Sequence[Dict[str, Any]],
    ) -> Tuple[List[PromptMessages], Optional[Dict[str, Any]]]:
        """Fill the template once per data item.

        Returns:
            A pair ``(prompts, response_format)``: one message list per item,
            and the first non-``None`` response format produced by the
            template (``None`` if no item produced one).
        """
        prompts: List[PromptMessages] = []
        shared_response_format: Optional[Dict[str, Any]] = None

        for item in data:
            messages, response_format = fill_template_file(str(template_path), item)
            prompts.append(messages)
            if shared_response_format is None:
                shared_response_format = response_format

        return prompts, shared_response_format

    @staticmethod
    def _try_json(candidate: str) -> Tuple[bool, Any]:
        """Parse ``candidate`` as JSON; return ``(ok, value)`` instead of raising."""
        try:
            return True, json.loads(candidate)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; RecursionError covers
            # pathologically deep nesting.
            return False, None

    @staticmethod
    def _last_json_object(text: str) -> Tuple[bool, Any]:
        """Find the last top-level JSON object embedded in ``text``.

        Scans forward from each ``{``. A successfully decoded object is
        skipped over as a whole, so objects nested inside it are never
        returned on their own.

        Returns:
            ``(True, obj)`` for the last decodable object, else ``(False, None)``.
        """
        decoder = json.JSONDecoder()
        found = False
        value: Any = None
        pos = text.find("{")
        while pos != -1:
            try:
                candidate, end = decoder.raw_decode(text, pos)
            except (ValueError, RecursionError):
                pos = text.find("{", pos + 1)
                continue
            found, value = True, candidate
            pos = text.find("{", end)
        return found, value

    @staticmethod
    def parse_json_response(text: str) -> Any:
        """Extract a JSON value from a model response.

        Strategies are tried in order:

        1. Parse the stripped text as-is.
        2. Parse the contents of the first Markdown code fence.
        3. Strip known thinking/reasoning prefixes, then parse the rest.
        4. Take the last top-level JSON object embedded anywhere in the text.

        Args:
            text: Raw response text.

        Returns:
            The decoded JSON value.

        Raises:
            ValueError: If no strategy yields valid JSON. The message contains
                the first 500 characters of the response.
        """
        text = text.strip()

        # Try normal parse first
        ok, value = Generator._try_json(text)
        if ok:
            return value

        # Strip code fences (e.g. ```json ... ```) common in Gemma responses
        fence_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
        if fence_match:
            ok, value = Generator._try_json(fence_match.group(1).strip())
            if ok:
                return value

        # Strip common thinking/reasoning prefixes from models like Qwen
        cleaned = text
        for pattern in Generator._THINKING_PATTERNS:
            match = re.match(pattern, cleaned)
            if match:
                cleaned = cleaned[match.end():]
        ok, value = Generator._try_json(cleaned.strip())
        if ok:
            return value

        # Fallback: the last complete top-level object is most likely the
        # actual output; earlier ones tend to be examples inside reasoning.
        # This also covers the "first '{' to last '}'" case: whenever that
        # slice is valid JSON, decoding from the first '{' succeeds here.
        ok, value = Generator._last_json_object(text)
        if ok:
            return value

        raise ValueError(f"Model response was not valid JSON:\n{text[:500]}")