# NOTE(review): the three lines below are Hugging Face web-UI residue (avatar alt text,
# commit message, commit hash) accidentally captured into the source file; preserved as
# comments so the module parses.
# sushilideaclan01's picture
# Restore changes from lost commit (excluding creativity_examples binaries for HF)
# 7067477
"""
LLM helpers: OpenAI client, call_llm with retry, extract_json from response.
"""
import json
import os
import re
import time
from openai import OpenAI
# Lazily-constructed, module-wide singleton OpenAI client.
_client: OpenAI | None = None


def get_client() -> OpenAI:
    """Return the shared OpenAI client, constructing it on first use.

    Raises:
        ValueError: if the OPENAI_API_KEY environment variable is unset/empty.
    """
    global _client
    if _client is not None:
        return _client
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set")
    _client = OpenAI(api_key=api_key)
    return _client
def extract_json(raw: str) -> dict:
    """Extract JSON from LLM output, handling optional markdown code blocks, trailing commas, and missing commas.

    Args:
        raw: raw model output; may be None/empty, and may wrap the JSON in a
            ```json ... ``` (or bare ```) fence.

    Returns:
        The parsed JSON value (typically a dict; a top-level list parses too).

    Raises:
        json.JSONDecodeError: if the text cannot be parsed even after repairs.
    """
    text = (raw or "").strip()
    # Remove optional ```json ... ``` wrapper
    match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
    if match:
        text = match.group(1).strip()
    # Fast path: if the text is already valid JSON, parse it untouched. This must
    # happen BEFORE the repair regexes below — they operate on raw text and would
    # otherwise corrupt valid documents whose string values contain "}{" etc.
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Repair heuristics for common LLM output errors:
    # Strip illegal trailing commas (e.g. ", }" or ", ]") so strict json.loads succeeds
    text = re.sub(r",(\s*[}\]])", r"\1", text)
    # Fix missing comma between array elements (common LLM error in long lists): "}\n {" -> "},\n {"
    text = re.sub(r"\}(\s*)\{", r"}, \1{", text)
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        # Retry after fixing missing comma between "value"\n "next_key" (end of string, next key)
        if "Expecting ',' delimiter" in str(e):
            text = re.sub(r'"(\s*)\n(\s*)"([^"]+)":', r'",\n\2"\3":', text)
            return json.loads(text)
        raise
def call_llm(
    prompt: str,
    model: str = "gpt-4o-mini",
    temperature: float = 0.85,
    max_retries: int = 3,
    retry_delay: float = 2.0,
    response_format: dict | None = None,
) -> str:
    """Calls the OpenAI API with retry logic. Returns raw text output.

    Args:
        prompt: user message content (sent as a single user-role message).
        model: chat model name.
        temperature: sampling temperature.
        max_retries: total number of attempts before giving up.
        retry_delay: base delay in seconds; grows linearly with the attempt number.
        response_format: optional structured-output spec passed through to the API.

    Returns:
        The assistant's text content, stripped ("" if the API returned None content).

    Raises:
        RuntimeError: when every attempt fails; chained from the last API error.
    """
    client = get_client()
    kwargs: dict = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
    }
    if response_format is not None:
        kwargs["response_format"] = response_format
    # Initialized before the loop so the final raise never sees an unbound name.
    last_error: Exception | None = None
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(**kwargs)
            content = response.choices[0].message.content
            return (content or "").strip()
        except Exception as e:
            last_error = e
            if attempt < max_retries:
                # Linear backoff: wait longer after each failed attempt.
                time.sleep(retry_delay * attempt)
    # All attempts exhausted (the original unreachable `return ""` is removed).
    raise RuntimeError(f"LLM call failed after {max_retries} attempts: {last_error}") from last_error
def call_llm_vision(
    messages: list[dict],
    model: str = "gpt-4o-mini",
    temperature: float = 0.85,
    max_retries: int = 3,
    retry_delay: float = 2.0,
    response_format: dict | None = None,
) -> str:
    """Calls the OpenAI Chat API with messages that may include image_url content parts.

    Args:
        messages: list of message dicts, e.g. [{"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "https://..."}}]}].
        model: chat model name.
        temperature: sampling temperature.
        max_retries: total number of attempts before giving up.
        retry_delay: base delay in seconds; grows linearly with the attempt number.
        response_format: optional dict for structured output, e.g. {"type": "json_schema", "json_schema": {...}}.

    Returns:
        The assistant's text content, stripped ("" if the API returned None content).

    Raises:
        RuntimeError: when every attempt fails; chained from the last API error.
    """
    client = get_client()
    kwargs: dict = {"model": model, "messages": messages, "temperature": temperature}
    if response_format is not None:
        kwargs["response_format"] = response_format
    # Initialized before the loop so the final raise never sees an unbound name.
    last_error: Exception | None = None
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(**kwargs)
            content = response.choices[0].message.content
            return (content or "").strip()
        except Exception as e:
            last_error = e
            if attempt < max_retries:
                # Linear backoff: wait longer after each failed attempt.
                time.sleep(retry_delay * attempt)
    # All attempts exhausted (the original unreachable `return ""` is removed).
    raise RuntimeError(f"LLM call failed after {max_retries} attempts: {last_error}") from last_error