"""Run the multi-step LLM diagnostics for a poster image via an Ollama server.

The module asks the model whether the image has content, requests one score per
metric in ``METRIC_KEYS``, reconciles those scores with image features, and
finally asks a judge model call for a label, confidence and rationale.
"""
from __future__ import annotations

import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from ollama import Client, ResponseError

from config import (
    MAX_RETRIES,
    MODEL_NAME,
    NUM_CTX,
    OLLAMA_HOST,
    RETRY_SLEEP_SEC,
    TEMPERATURE,
    TOP_P,
)
from final_judge import run_final_judge
from prompt import (
    SYSTEM_PROMPT,
    build_metric_prompt,
    build_presence_prompt,
)
from reconciler import build_structured_summary, reconcile_metrics
from schema import (
    METRIC_KEYS,
    PRESENCE_SCHEMA,
    metric_schema,
    validate_scores,
)

# Exclude localhost from proxying and drop every proxy variable from the
# process environment before the client is created.
# NOTE(review): presumably so requests to a local Ollama server are never
# routed through a corporate/system proxy -- confirm.
os.environ["NO_PROXY"] = "localhost,127.0.0.1"
os.environ["no_proxy"] = "localhost,127.0.0.1"

for key in [
    "HTTP_PROXY",
    "HTTPS_PROXY",
    "ALL_PROXY",
    "http_proxy",
    "https_proxy",
    "all_proxy",
]:
    os.environ.pop(key, None)

# Shared client used by every request in this module.
CLIENT = Client(host=OLLAMA_HOST, timeout=120.0)


def safe_json_loads(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse a JSON object out of a model response.

    The whole text is tried first; if that does not yield a JSON object, the
    span from the first ``{`` to the last ``}`` is tried instead.

    Args:
        text: Raw model output. ``None`` or blank text is treated as empty.

    Returns:
        ``(parsed, None)`` on success, where ``parsed`` is always a ``dict``,
        or ``(None, error_message)`` when no JSON object could be decoded.
    """
    raw = (text or "").strip()
    if not raw:
        return None, "Empty response"

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        parsed = None

    # Valid JSON that is not an object (list, number, string, ...) must not be
    # handed back as if it were a dict; fall through to brace extraction.
    if isinstance(parsed, dict):
        return parsed, None

    start = raw.find("{")
    end = raw.rfind("}")
    if start != -1 and end > start:
        fragment = raw[start:end + 1]
        try:
            # Text that starts with "{" and decodes successfully is an object.
            return json.loads(fragment), None
        except json.JSONDecodeError as e:
            return None, f"JSON decode error after extraction: {e}"

    return None, "No JSON object found in response"


def clamp_int_1_5(x: Any) -> int:
    """Coerce *x* to an integer score in the closed range 1..5.

    The value is converted with ``float`` and rounded with the built-in
    ``round`` before clamping. Values that cannot be converted (wrong type,
    non-numeric text, NaN, infinity, overflow) yield the fallback score 3.
    """
    try:
        score = int(round(float(x)))
    except (TypeError, ValueError, OverflowError):
        return 3
    return max(1, min(5, score))


def check_ollama_available() -> None:
    """Verify the Ollama server answers ``list()`` with a ``models`` entry.

    Raises:
        RuntimeError: If the response has no ``models`` key. Errors raised by
            the client itself propagate unchanged.
    """
    response = CLIENT.list()
    if "models" not in response:
        raise RuntimeError("Ollama list() did not return models")


def check_model_present(model_name: str) -> None:
    """Verify that a model whose name contains *model_name* is installed.

    Each listed model is identified by its ``model`` field, falling back to
    ``name``. Matching is by substring, which also covers exact equality.

    Raises:
        RuntimeError: If no listed model name contains *model_name*.
    """
    response = CLIENT.list()
    candidates = (
        item.get("model") or item.get("name")
        for item in response.get("models", [])
    )
    if not any(name and model_name in name for name in candidates):
        raise RuntimeError(
            f"Model '{model_name}' not found. Run: ollama pull {model_name}"
        )


def smoke_test_text() -> None:
    """Send a minimal text-only chat request and log the stripped reply."""
    response = CLIENT.chat(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Reply with exactly: OK"}],
        options={"temperature": 0.0, "top_p": 1.0, "num_ctx": 512},
    )
    content = response["message"]["content"].strip()
    logging.info("Smoke test response: %s", content)


def call_ollama(
    image_path: Path,
    user_prompt: str,
    schema: Dict[str, Any],
    temperature: float = TEMPERATURE,
) -> str:
    """Send one image + prompt chat request and return the reply text.

    Args:
        image_path: Image attached to the user message (passed as a string path).
        user_prompt: Text of the user message.
        schema: Passed to the client as ``format`` for the response.
        temperature: Sampling temperature; ``TOP_P`` and ``NUM_CTX`` come from
            the configuration.

    Returns:
        The ``message.content`` field of the response.

    Raises:
        ValueError: If the response lacks ``message`` or ``message.content``.
    """
    response = CLIENT.chat(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt, "images": [str(image_path)]},
        ],
        format=schema,
        options={
            "temperature": temperature,
            "top_p": TOP_P,
            "num_ctx": NUM_CTX,
        },
    )

    if "message" not in response or "content" not in response["message"]:
        raise ValueError(f"Unexpected Ollama SDK response format: {response}")

    return response["message"]["content"]


def ask_json(
    image_path: Path,
    user_prompt: str,
    schema: Dict[str, Any],
    step_name: str,
) -> Tuple[Dict[str, Any], str]:
    """Query the model until it returns a parseable JSON object.

    Up to ``MAX_RETRIES + 1`` attempts are made. The first attempt uses
    ``TEMPERATURE``; every retry uses temperature 0.0. After a failed request
    the function waits before retrying (``RETRY_SLEEP_SEC`` for a response
    error mentioning "503", 2 seconds for any other non-``ResponseError``
    exception); it does not wait once no attempts remain.

    Args:
        image_path: Image sent with the request; its name prefixes log lines.
        user_prompt: Prompt text for this step.
        schema: Response format schema forwarded to ``call_ollama``.
        step_name: Label used in log lines and error messages.

    Returns:
        ``(parsed_object, raw_response_text)`` for the first successful attempt.

    Raises:
        RuntimeError: If every attempt failed. The message describes the last
            failure and the last caught exception (if any) is chained as the
            cause.
    """
    filename = image_path.name
    total_attempts = MAX_RETRIES + 1
    last_error = ""
    last_exc: Optional[BaseException] = None

    for attempt in range(total_attempts):
        attempt_no = attempt + 1
        retry_left = attempt_no < total_attempts
        try:
            # Monotonic clock: the elapsed time is immune to wall-clock jumps.
            step_started = time.monotonic()
            logging.info(
                "%s | step=%s | attempt=%d | request_start",
                filename, step_name, attempt_no,
            )

            raw_response = call_ollama(
                image_path=image_path,
                user_prompt=user_prompt,
                schema=schema,
                temperature=0.0 if attempt > 0 else TEMPERATURE,
            )

            elapsed = time.monotonic() - step_started
            logging.info(
                "%s | step=%s | attempt=%d | request_done | sec=%.2f",
                filename, step_name, attempt_no, elapsed,
            )

            parsed, parse_error = safe_json_loads(raw_response)
            if parsed is None:
                last_error = f"{step_name}: parse error: {parse_error}"
                # The most recent failure is a parse error, not an exception.
                last_exc = None
                logging.warning("%s | attempt=%d | %s", filename, attempt_no, last_error)
                logging.info("%s | raw response preview: %s", filename, raw_response[:500])
                continue

            return parsed, raw_response

        except ResponseError as e:
            last_error = f"{step_name}: {e}"
            last_exc = e
            logging.exception("%s | attempt=%d failed", filename, attempt_no)
            # Back off only for errors mentioning 503, and only if a retry follows.
            if retry_left and "503" in str(e):
                time.sleep(RETRY_SLEEP_SEC)

        except Exception as e:
            last_error = f"{step_name}: {e}"
            last_exc = e
            logging.exception("%s | attempt=%d failed", filename, attempt_no)
            if retry_left:
                time.sleep(2)

    raise RuntimeError(
        last_error or f"{step_name}: model failed after retries"
    ) from last_exc


def build_comment_from_label_and_summary(label: str, summary_text: str, rationale: str) -> str:
    """Build the short human-readable comment for a result.

    Args:
        label: Final label; selects the fallback sentence ("good", "bad", or
            anything else) when *rationale* is empty.
        summary_text: Accepted for interface compatibility; not used.
        rationale: Judge rationale. Surrounding whitespace and trailing periods
            are removed before use.

    Returns:
        The rationale terminated by a single period, or, when it is longer than
        170 characters, its first 167 characters followed by ``"..."`` (at most
        170 characters, with no extra period after the ellipsis). When the
        rationale is empty, a fixed sentence chosen by *label*.
    """
    base = rationale.strip().rstrip(".")

    if not base:
        if label == "good":
            return "The poster shows strong overall design control."
        if label == "bad":
            return "The poster shows weak overall design control."
        return "The poster shows mixed design quality."

    if len(base) > 170:
        # Strip a period at the cut point so the ellipsis is exactly three dots.
        return base[:167].rstrip().rstrip(".").rstrip() + "..."

    return base + "."


def _no_content_result(raw_trace: Dict[str, Any]) -> Dict[str, Any]:
    """Return the fixed result used when the presence step reports no content."""
    return {
        "content_present": False,
        "hierarchy_strength": 1,
        "focal_clarity": 1,
        "typography_control": 1,
        "font_conflict": 1,
        "palette_control": 1,
        "color_conflict": 1,
        "visual_clutter": 1,
        "template_genericity": 1,
        "originality": 1,
        "ai_generated_score": 1,
        "technical_quality": 3,
        "visual_impact": 1,
        "confidence": 0.95,
        "comment": "No meaningful poster content detected.",
        "_final_label_prior": "bad",
        "_raw_response": json.dumps(raw_trace, ensure_ascii=False),
    }


def get_llm_diagnostics(image_path: Path, img_feat: Dict[str, Any]) -> Dict[str, Any]:
    """Run the full diagnostics pipeline for one image.

    Steps: presence check, one model request per key in ``METRIC_KEYS``
    (each score clamped to 1..5), reconciliation against *img_feat*, a
    structured summary, the final judge call, and validation of the result.

    Args:
        image_path: Image to analyse.
        img_feat: Image features forwarded to ``reconcile_metrics``.

    Returns:
        A result dict with ``content_present``, the metric scores,
        ``confidence``, ``comment``, ``_final_label_prior`` and a JSON string
        ``_raw_response`` holding the intermediate data. When the presence step
        reports no content, a fixed result is returned without further requests.

    Raises:
        RuntimeError: If a model step fails after all retries (from
            ``ask_json``) or if ``validate_scores`` rejects the result.
        KeyError: If a step's JSON lacks the expected key.
    """
    filename = image_path.name
    raw_trace: Dict[str, Any] = {}

    logging.info("%s | step=presence | start", filename)
    presence_data, raw = ask_json(
        image_path=image_path,
        user_prompt=build_presence_prompt(filename),
        schema=PRESENCE_SCHEMA,
        step_name="presence",
    )
    logging.info("%s | step=presence | done", filename)
    raw_trace["presence"] = raw

    if not bool(presence_data["content_present"]):
        return _no_content_result(raw_trace)

    raw_metrics: Dict[str, int] = {}
    for metric_name in METRIC_KEYS:
        metric_data, raw = ask_json(
            image_path=image_path,
            user_prompt=build_metric_prompt(filename, metric_name),
            schema=metric_schema(metric_name),
            step_name=metric_name,
        )
        raw_metrics[metric_name] = clamp_int_1_5(metric_data[metric_name])
        raw_trace[metric_name] = raw

    reconciled_metrics, warnings = reconcile_metrics(raw_metrics, img_feat)
    summary_text = build_structured_summary(reconciled_metrics, warnings)

    final_judge = run_final_judge(
        client=CLIENT,
        model_name=MODEL_NAME,
        top_p=TOP_P,
        num_ctx=NUM_CTX,
        image_path=image_path,
        summary_text=summary_text,
        metrics=reconciled_metrics,
        system_prompt=SYSTEM_PROMPT,
    )

    label = str(final_judge["label"])
    confidence = float(final_judge["confidence"])
    rationale = str(final_judge["rationale"]).strip()

    result = {
        "content_present": True,
        **reconciled_metrics,
        "confidence": confidence,
        "comment": build_comment_from_label_and_summary(label, summary_text, rationale),
        "_final_label_prior": label,
        "_raw_response": json.dumps(
            {
                "raw_trace": raw_trace,
                "raw_metrics": raw_metrics,
                "reconciled_metrics": reconciled_metrics,
                "warnings": warnings,
                "summary_text": summary_text,
                "final_judge": final_judge,
            },
            ensure_ascii=False,
        ),
    }

    is_valid, validation_msg = validate_scores(result)
    if not is_valid:
        raise RuntimeError(f"Validation error: {validation_msg} | result={result}")

    return result