Spaces:
Sleeping
Sleeping
| """ | |
| LLM helpers: OpenAI client, call_llm with retry, extract_json from response. | |
| """ | |
| import json | |
| import os | |
| import re | |
| import time | |
| from openai import OpenAI | |
# Shared module-level OpenAI client, created lazily on first use.
_client: OpenAI | None = None


def get_client() -> OpenAI:
    """Return the process-wide OpenAI client, creating it on first call.

    The client is cached in the module-level ``_client`` so repeated calls
    reuse one connection pool.

    Returns:
        The shared ``OpenAI`` client instance.

    Raises:
        ValueError: If the OPENAI_API_KEY environment variable is unset/empty.
    """
    global _client
    if _client is not None:
        return _client
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set")
    _client = OpenAI(api_key=api_key)
    return _client
def extract_json(raw: str) -> dict:
    """Extract a JSON payload from raw LLM output.

    Handles three common failure modes of model-emitted JSON:
    1. A markdown ```json ... ``` (or bare ```) code fence around the payload.
    2. Illegal trailing commas before a closing brace/bracket.
    3. Missing commas between values/keys/array elements (frequent in long
       lists).

    Strict parsing is attempted FIRST, so already-valid JSON is returned
    untouched.  The regex repairs only run after a parse failure; previously
    they ran unconditionally and could corrupt legal string values that
    merely look like broken delimiters (e.g. a value containing "}{" was
    rewritten to "}, {").

    Args:
        raw: Raw model output text (may be None or empty).

    Returns:
        The parsed JSON value (typically a dict for object payloads).

    Raises:
        json.JSONDecodeError: If the text cannot be parsed even after the
            repair passes.
    """
    text = (raw or "").strip()
    # Unwrap an optional markdown code fence.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
    if fence:
        text = fence.group(1).strip()

    # Fast path: valid JSON needs no mutation at all.
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Repair pass 1: strip illegal trailing commas (", }" / ", ]") and add
    # a missing comma between adjacent objects ("}\n{" -> "},\n{").
    repaired = re.sub(r",(\s*[}\]])", r"\1", text)
    repaired = re.sub(r"\}(\s*)\{", r"}, \1{", repaired)
    try:
        return json.loads(repaired)
    except json.JSONDecodeError as e:
        # Only the missing-comma heuristics below can help this error class;
        # anything else is re-raised unchanged.
        if "Expecting ',' delimiter" not in str(e):
            raise

    # Repair pass 2: common LLM missing-comma patterns.
    # Missing comma after a string value, before the next key.
    repaired = re.sub(r'"(\s*)\n(\s*)"([^"]+)":', r'",\n\2"\3":', repaired)
    # Missing comma after a number / true / false / null, before the next key.
    repaired = re.sub(r'(\d+)(\s*)\n(\s*)"([^"]+)":', r'\1,\2\n\3"\4":', repaired)
    repaired = re.sub(r'(true|false|null)(\s*)\n(\s*)"([^"]+)":', r'\1,\2\n\3"\4":', repaired)
    # Missing comma between array elements: "]\n[" or "]\n{".
    repaired = re.sub(r'\](\s*)\n(\s*)\[', r'],\1\n\2[', repaired)
    repaired = re.sub(r'\](\s*)\n(\s*)\{', r'],\1\n\2{', repaired)
    # Missing comma after an object end, before the next key in the parent.
    repaired = re.sub(r'\}(\s*)\n(\s*)"([^"]+)":', r'},\1\n\2"\3":', repaired)
    # Final attempt; raises json.JSONDecodeError if still unparseable.
    return json.loads(repaired)
def call_llm(
    prompt: str,
    model: str = "gpt-4o-mini",
    temperature: float = 0.85,
    max_retries: int = 3,
    retry_delay: float = 2.0,
) -> str:
    """Send a single-turn user prompt to the OpenAI chat API, with retries.

    Failed attempts back off linearly: the wait before retry N+1 is
    ``retry_delay * N`` seconds.

    Args:
        prompt: Text sent as the sole user message.
        model: Chat model identifier.
        temperature: Sampling temperature.
        max_retries: Total number of attempts before giving up.
        retry_delay: Base backoff delay in seconds.

    Returns:
        The assistant's text output, stripped ("" if the API returned None).

    Raises:
        RuntimeError: If every attempt fails; chained to the last error.
    """
    client = get_client()
    for attempt in range(1, max_retries + 1):
        try:
            reply = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
            )
            text = reply.choices[0].message.content
            return (text or "").strip()
        except Exception as exc:
            if attempt >= max_retries:
                raise RuntimeError(
                    f"LLM call failed after {max_retries} attempts: {exc}"
                ) from exc
            # Linear backoff before the next attempt.
            time.sleep(retry_delay * attempt)
    # Unreachable when max_retries >= 1 (kept for a zero-retry call).
    return ""
def call_llm_vision(
    messages: list[dict],
    model: str = "gpt-4o-mini",
    temperature: float = 0.85,
    max_retries: int = 3,
    retry_delay: float = 2.0,
) -> str:
    """Send pre-built chat messages (possibly with images) to the OpenAI API.

    Unlike :func:`call_llm`, the caller supplies the full message list, so
    content parts may include ``image_url`` entries, e.g.::

        [{"role": "user", "content": [
            {"type": "text", "text": "..."},
            {"type": "image_url", "image_url": {"url": "https://..."}}]}]

    Failed attempts back off linearly (``retry_delay * attempt`` seconds).

    Args:
        messages: Chat messages passed through to the API unchanged.
        model: Chat model identifier.
        temperature: Sampling temperature.
        max_retries: Total number of attempts before giving up.
        retry_delay: Base backoff delay in seconds.

    Returns:
        The assistant's text output, stripped ("" if the API returned None).

    Raises:
        RuntimeError: If every attempt fails; chained to the last error.
    """
    client = get_client()
    for attempt in range(1, max_retries + 1):
        try:
            reply = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            text = reply.choices[0].message.content
            return (text or "").strip()
        except Exception as exc:
            if attempt >= max_retries:
                raise RuntimeError(
                    f"LLM call failed after {max_retries} attempts: {exc}"
                ) from exc
            # Linear backoff before the next attempt.
            time.sleep(retry_delay * attempt)
    # Unreachable when max_retries >= 1 (kept for a zero-retry call).
    return ""