import hashlib
import json
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from azure.identity import (
    AzureCliCredential,
    ChainedTokenCredential,
    ManagedIdentityCredential,
    get_bearer_token_provider,
)
from openai import AzureOpenAI
from tqdm import tqdm

from src.config import (
    DEFAULT_MODEL,
    MAX_TOKENS,
    MAX_TOKENS_THINKING,
    TEMPERATURE,
    TRAPI_SCOPE,
    build_azure_endpoint,
)
from src.model_registry import get_model_config
from src.generation_utils import fill_template_file
from src.generation_rates import estimate_cost

Message = Dict[str, str]
PromptMessages = List[Message]


class Generator:
    """Wrapper around an Azure OpenAI chat deployment with prompt caching,
    retries, per-request logging, and usage/cost tracking."""

    def __init__(
        self,
        model_alias: Optional[str] = None,
        use_cache: bool = True,
        cache_path: Optional[Union[str, Path]] = None,
        max_retries: int = 5,
        retry_wait_seconds: int = 90,
    ) -> None:
        self.model_alias = model_alias or DEFAULT_MODEL
        self.model_config = get_model_config(self.model_alias)
        self.use_cache = use_cache
        self.max_retries = max_retries
        self.retry_wait_seconds = retry_wait_seconds

        # Usage tracking
        self.usage_stats = {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_cached_tokens": 0,
            "total_requests": 0,
            "cache_hits": 0,
        }

        # Request log file (append-only JSONL, one per model)
        self.request_log_path = (
            Path("generation_logs") / f"requests_{self.model_alias}.jsonl"
        )
        self.request_log_path.parent.mkdir(parents=True, exist_ok=True)

        self.cache_path = (
            Path(cache_path)
            if cache_path
            else Path("generation_logs/prompt_cache.json")
        )
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)

        self.client = self._build_client()
        self.prompt_cache: Dict[str, str] = self._load_prompt_cache()

    def _build_client(self) -> AzureOpenAI:
        """Build an AzureOpenAI client authenticated via Azure AD (CLI or managed identity)."""
        token_provider = get_bearer_token_provider(
            ChainedTokenCredential(
                AzureCliCredential(),
                ManagedIdentityCredential(),
            ),
            TRAPI_SCOPE,
        )
        endpoint = (
            self.model_config.endpoint_override
            if self.model_config.endpoint_override
            else build_azure_endpoint()
        )
        return AzureOpenAI(
            azure_endpoint=endpoint,
            azure_ad_token_provider=token_provider,
            api_version=self.model_config.api_version,
        )

    def _load_prompt_cache(self) -> Dict[str, str]:
        """Load the prompt cache from disk; return an empty dict if caching is off or the file is missing/unreadable."""
        if not self.use_cache:
            return {}
        if not self.cache_path.exists():
            return {}
        try:
            return json.loads(self.cache_path.read_text(encoding="utf-8"))
        except Exception:
            return {}

    def _save_prompt_cache(self) -> None:
        """Persist the in-memory prompt cache to disk."""
        if not self.use_cache:
            return
        self.cache_path.write_text(
            json.dumps(self.prompt_cache, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )

    def _make_cache_key(
        self,
        model_alias: str,
        messages: PromptMessages,
        temperature: float,
        max_completion_tokens: int,
        response_format: Optional[Dict[str, Any]],
    ) -> str:
        """Return a stable SHA-256 key over the request parameters for the prompt cache."""
        payload = {
            "model_alias": model_alias,
            "messages": messages,
            "temperature": temperature,
            "max_completion_tokens": max_completion_tokens,
            "response_format": response_format,
        }
        raw = json.dumps(payload, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()

    def _chat_once(
        self,
        messages: PromptMessages,
        temperature: float,
        max_completion_tokens: int,
        response_format: Optional[Dict[str, Any]],
    ) -> str:
        """Issue a single chat completion request, record usage, and return the response text."""
        kwargs: Dict[str, Any] = {
            "model": self.model_config.deployment_name,
            "messages": messages,
            "max_completion_tokens": max_completion_tokens,
            "temperature": temperature,
        }
        # Only pass structured response_format to models that support it.
        if response_format is not None and self.model_config.is_openai_compatible:
            kwargs["response_format"] = response_format
        # Disable thinking for Qwen 3.5 models to avoid slow reasoning tokens.
        if "qwen" in self.model_alias.lower() and "3.5" in self.model_alias:
            kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}

        response = self.client.chat.completions.create(**kwargs)
        content = response.choices[0].message.content or ""

        # Capture usage stats.
        usage = getattr(response, "usage", None)
        if usage:
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0
            cached_tokens = 0
            prompt_details = getattr(usage, "prompt_tokens_details", None)
            if prompt_details:
                cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0
            self.usage_stats["total_prompt_tokens"] += prompt_tokens
            self.usage_stats["total_completion_tokens"] += completion_tokens
            self.usage_stats["total_cached_tokens"] += cached_tokens
            self.usage_stats["total_requests"] += 1
            self._log_request(
                messages=messages,
                temperature=temperature,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                cached_tokens=cached_tokens,
            )
        return content

    def _log_request(
        self,
        messages: PromptMessages,
        temperature: float,
        max_completion_tokens: int,
        response_format: Optional[Dict[str, Any]],
        prompt_tokens: int,
        completion_tokens: int,
        cached_tokens: int,
    ) -> None:
        """Append a single request record to the JSONL log."""
        record = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "model_alias": self.model_alias,
            "deployment": self.model_config.deployment_name,
            "temperature": temperature,
            "max_completion_tokens": max_completion_tokens,
            "has_response_format": response_format is not None,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "cached_tokens": cached_tokens,
            "cost_usd": estimate_cost(
                self.model_alias, prompt_tokens, completion_tokens, cached_tokens
            ),
        }
        with open(self.request_log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    def chat(
        self,
        messages: PromptMessages,
        temperature: Optional[float] = None,
        max_completion_tokens: Optional[int] = None,
        response_format: Optional[Dict[str, Any]] = None,
        dont_use_cached: bool = False,
    ) -> str:
        """Send one chat request with caching and retries; return the response text."""
        final_temperature = TEMPERATURE if temperature is None else temperature
        if max_completion_tokens is not None:
            final_max_tokens = max_completion_tokens
        elif not self.model_config.is_openai_compatible:
            final_max_tokens = MAX_TOKENS_THINKING
        else:
            final_max_tokens = MAX_TOKENS

        cache_key = self._make_cache_key(
            model_alias=self.model_alias,
            messages=messages,
            temperature=final_temperature,
            max_completion_tokens=final_max_tokens,
            response_format=response_format,
        )
        if self.use_cache and not dont_use_cached and cache_key in self.prompt_cache:
            self.usage_stats["cache_hits"] += 1
            return self.prompt_cache[cache_key]

        last_error: Optional[Exception] = None
        for attempt in range(1, self.max_retries + 1):
            try:
                content = self._chat_once(
                    messages=messages,
                    temperature=final_temperature,
                    max_completion_tokens=final_max_tokens,
                    response_format=response_format,
                )
                if self.use_cache:
                    self.prompt_cache[cache_key] = content
                    self._save_prompt_cache()
                return content
            except Exception as exc:
                last_error = exc
                is_last_attempt = attempt == self.max_retries
                print(
                    f"[Generator] Model call failed on attempt "
                    f"{attempt}/{self.max_retries} for model '{self.model_alias}': {exc}"
                )
                if is_last_attempt:
                    break
                print(
                    f"[Generator] Waiting {self.retry_wait_seconds} seconds before retry..."
                )
                time.sleep(self.retry_wait_seconds)

        raise RuntimeError(
            f"Model call failed after {self.max_retries} attempts "
            f"for model '{self.model_alias}'."
        ) from last_error

    def prompt(
        self,
        prompts: Sequence[PromptMessages],
        temperature: Optional[float] = None,
        max_completion_tokens: Optional[int] = None,
        response_format: Optional[Dict[str, Any]] = None,
        dont_use_cached: bool = False,
        skip_failures: bool = False,
    ) -> List[Optional[str]]:
        """Run `chat` over a sequence of prompts; if skip_failures is set, failed items yield None."""
        outputs: List[Optional[str]] = []
        for prompt_messages in tqdm(
            prompts,
            desc=f"[Generator:{self.model_alias}]",
            unit="req",
        ):
            try:
                output = self.chat(
                    messages=prompt_messages,
                    temperature=temperature,
                    max_completion_tokens=max_completion_tokens,
                    response_format=response_format,
                    dont_use_cached=dont_use_cached,
                )
                outputs.append(output)
            except Exception as exc:
                if skip_failures:
                    print(f"[Generator] Skipping failed item: {exc}")
                    outputs.append(None)
                else:
                    raise
        return outputs

    def get_usage_summary(self) -> Dict[str, Any]:
        """Return usage stats with estimated cost."""
        stats = dict(self.usage_stats)
        stats["model_alias"] = self.model_alias
        stats["estimated_cost_usd"] = estimate_cost(
            self.model_alias,
            stats["total_prompt_tokens"],
            stats["total_completion_tokens"],
            stats["total_cached_tokens"],
        )
        return stats

    def print_usage_summary(self, stage: str = "") -> None:
        """Print a human-readable usage summary."""
        s = self.get_usage_summary()
        label = f"[{stage}] " if stage else ""
        cost_str = (
            f"${s['estimated_cost_usd']:.4f}"
            if s["estimated_cost_usd"] is not None
            else "N/A (no pricing)"
        )
        print(
            f"{label}Usage for {s['model_alias']}: "
            f"{s['total_prompt_tokens']} prompt tokens, "
            f"{s['total_completion_tokens']} completion tokens, "
            f"{s['total_cached_tokens']} cached tokens | "
            f"{s['total_requests']} API calls, "
            f"{s['cache_hits']} cache hits | "
            f"Est. cost: {cost_str}"
        )

    def build_prompts(
        self,
        template_path: Union[str, Path],
        data: Sequence[Dict[str, Any]],
    ) -> Tuple[List[PromptMessages], Optional[Dict[str, Any]]]:
        """Fill the prompt template for each data item; return the prompts and the shared response format."""
        prompts: List[PromptMessages] = []
        shared_response_format: Optional[Dict[str, Any]] = None
        for item in data:
            messages, response_format = fill_template_file(str(template_path), item)
            prompts.append(messages)
            if shared_response_format is None:
                shared_response_format = response_format
        return prompts, shared_response_format


def parse_json_response(text: str) -> Any:
    """Parse JSON from a model response, tolerating code fences, thinking blocks, and surrounding noise."""
    text = text.strip()

    # Try a normal parse first.
    try:
        return json.loads(text)
    except Exception:
        pass

    # Strip code fences (e.g. ```json ... ```) common in Gemma responses.
    fence_pattern = r"```(?:json)?\s*([\s\S]*?)```"
    fence_match = re.search(fence_pattern, text)
    if fence_match:
        try:
            return json.loads(fence_match.group(1).strip())
        except Exception:
            pass

    # Strip common thinking/reasoning prefixes from models like Qwen:
    # look for JSON after thinking blocks.
    thinking_patterns = [
        r"(?s).*?</think>\s*",
        r"(?s)^Thinking Process:.*?(?=\{)",
        r"(?s)^<think>.*?</think>\s*",
    ]
    cleaned = text
    for pattern in thinking_patterns:
        match = re.match(pattern, cleaned)
        if match:
            cleaned = cleaned[match.end():]
            try:
                return json.loads(cleaned.strip())
            except Exception:
                pass

    # Fallback: find the last complete JSON object (most likely the actual output).
    # Search from the end to avoid noise from thinking blocks.
    brace_positions = [i for i, c in enumerate(text) if c == '{']
    for start in reversed(brace_positions):
        depth = 0
        in_string = False
        escape_next = False
        for i in range(start, len(text)):
            ch = text[i]
            if escape_next:
                escape_next = False
                continue
            if ch == '\\' and in_string:
                escape_next = True
                continue
            if ch == '"':
                in_string = not in_string
                continue
            if in_string:
                continue
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    candidate = text[start:i + 1]
                    try:
                        return json.loads(candidate)
                    except Exception:
                        break  # try the next candidate start position

    # Original simple fallback: widest brace-delimited span.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        json_str = text[start:end + 1]
        try:
            return json.loads(json_str)
        except Exception:
            pass

    raise ValueError(f"Model response was not valid JSON:\n{text[:500]}")