Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Tuple | |
| from ollama import Client, ResponseError | |
| from config import ( | |
| MAX_RETRIES, | |
| MODEL_NAME, | |
| NUM_CTX, | |
| OLLAMA_HOST, | |
| RETRY_SLEEP_SEC, | |
| TEMPERATURE, | |
| TOP_P, | |
| ) | |
| from final_judge import run_final_judge | |
| from prompt import ( | |
| SYSTEM_PROMPT, | |
| build_metric_prompt, | |
| build_presence_prompt, | |
| ) | |
| from reconciler import build_structured_summary, reconcile_metrics | |
| from schema import ( | |
| METRIC_KEYS, | |
| PRESENCE_SCHEMA, | |
| metric_schema, | |
| validate_scores, | |
| ) | |
# Requests go to a local Ollama server: exempt loopback hosts from proxying
# and drop any proxy variables inherited from the environment.
os.environ.update(NO_PROXY="localhost,127.0.0.1", no_proxy="localhost,127.0.0.1")
for key in (
    "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
    "http_proxy", "https_proxy", "all_proxy",
):
    os.environ.pop(key, None)
# Module-wide Ollama client; every request in this module goes through it.
# The timeout value is forwarded to the underlying HTTP client.
CLIENT = Client(host=OLLAMA_HOST, timeout=120.0)
def safe_json_loads(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse a model reply into a JSON object, tolerating surrounding noise.

    First tries to decode the whole (stripped) text. If that fails, or the
    text decodes to something other than an object, falls back to the span
    between the first ``{`` and the last ``}``.

    Args:
        text: Raw model output; ``None`` is treated as empty.

    Returns:
        ``(obj, None)`` on success, where ``obj`` is always a ``dict``;
        ``(None, reason)`` on failure.
    """
    raw = (text or "").strip()
    if not raw:
        return None, "Empty response"
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        parsed = None
    # A valid-but-non-object reply (list, string, number, null) would break
    # callers that index the result by key, so only dicts are accepted here.
    if isinstance(parsed, dict):
        return parsed, None
    start = raw.find("{")
    end = raw.rfind("}")
    if start != -1 and end > start:
        try:
            fragment_obj = json.loads(raw[start:end + 1])
        except json.JSONDecodeError as e:
            return None, f"JSON decode error after extraction: {e}"
        if isinstance(fragment_obj, dict):
            return fragment_obj, None
    return None, "No JSON object found in response"
def clamp_int_1_5(x: Any) -> int:
    """Coerce a model-provided score to an integer in the range 1..5.

    Numbers and numeric strings are rounded half-up (2.5 -> 3, 3.5 -> 4) and
    then clamped. Values that cannot be converted, as well as NaN and
    infinities, fall back to the neutral score 3.
    """
    try:
        value = float(x)
        # Built-in round() rounds halves to even (2.5 -> 2 but 3.5 -> 4),
        # which is inconsistent for scores; add 0.5 and truncate instead.
        # int() truncates toward zero, which differs from floor only for
        # negative values, and those are clamped to 1 below anyway.
        score = int(value + 0.5)
    except (TypeError, ValueError, OverflowError):
        return 3
    return max(1, min(5, score))
def check_ollama_available() -> None:
    """Raise if the Ollama model listing has no ``models`` entry.

    Exceptions raised by ``CLIENT.list()`` itself (for example when the
    server is unreachable) propagate unchanged.

    Raises:
        RuntimeError: the listing was returned but lacks ``models``.
    """
    listing = CLIENT.list()
    if "models" in listing:
        return
    raise RuntimeError("Ollama list() did not return models")
def check_model_present(model_name: str) -> None:
    """Raise unless ``model_name`` appears among the locally listed models.

    Each listed entry is named by its ``model`` field, falling back to
    ``name``. A match is either an exact name or a listed name that contains
    ``model_name`` as a substring.

    NOTE(review): the substring rule also accepts unrelated models whose name
    merely contains ``model_name`` (e.g. "llava" vs "llava-phi3").

    Raises:
        RuntimeError: no listed model matches.
    """
    listing = CLIENT.list()
    candidates = [
        entry.get("model") or entry.get("name")
        for entry in listing.get("models", [])
    ]
    installed = [candidate for candidate in candidates if candidate]
    found = any(model_name == installed_name or model_name in installed_name for installed_name in installed)
    if found:
        return
    raise RuntimeError(f"Model '{model_name}' not found. Run: ollama pull {model_name}")
def smoke_test_text() -> None:
    """Send a trivial text-only chat to ``MODEL_NAME`` and log the reply.

    The reply is only logged, not compared against "OK". Errors from the chat
    call, or a response without ``message``/``content``, propagate.
    """
    probe_messages = [{"role": "user", "content": "Reply with exactly: OK"}]
    probe_options = {"temperature": 0.0, "top_p": 1.0, "num_ctx": 512}
    reply = CLIENT.chat(
        model=MODEL_NAME,
        messages=probe_messages,
        options=probe_options,
    )
    reply_text = reply["message"]["content"].strip()
    logging.info(f"Smoke test response: {reply_text}")
def call_ollama(
    image_path: Path,
    user_prompt: str,
    schema: Dict[str, Any],
    temperature: float = TEMPERATURE,
) -> str:
    """Run one schema-constrained chat turn about an image and return its text.

    The conversation is ``SYSTEM_PROMPT`` followed by ``user_prompt`` with the
    image path attached. ``schema`` is forwarded as the ``format`` argument,
    and sampling uses ``temperature``, ``TOP_P`` and ``NUM_CTX``.

    Raises:
        ValueError: the response has no ``message``/``content`` entry.
    """
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt, "images": [str(image_path)]},
    ]
    sampling = {"temperature": temperature, "top_p": TOP_P, "num_ctx": NUM_CTX}
    response = CLIENT.chat(
        model=MODEL_NAME,
        messages=conversation,
        format=schema,
        options=sampling,
    )
    has_content = "message" in response and "content" in response["message"]
    if not has_content:
        raise ValueError(f"Unexpected Ollama SDK response format: {response}")
    return response["message"]["content"]
def ask_json(
    image_path: Path,
    user_prompt: str,
    schema: Dict[str, Any],
    step_name: str,
) -> Tuple[Dict[str, Any], str]:
    """Ask the model one structured question about an image, with retries.

    Makes up to ``MAX_RETRIES + 1`` attempts. The first attempt samples at
    ``TEMPERATURE`` and retries use temperature 0.0. An attempt fails when:

    * the Ollama call raises;
    * the reply is not a JSON object; or
    * the object is missing a key listed in ``schema["required"]``. This
      check applies only when the schema declares that list.

    Args:
        image_path: Image sent alongside the prompt; its name is used in logs.
        user_prompt: Prompt text for this step.
        schema: JSON schema forwarded to ``call_ollama`` as the output format.
        step_name: Label used in log lines and error messages.

    Returns:
        ``(parsed, raw_response)``: the decoded object and the raw reply text.

    Raises:
        RuntimeError: every attempt failed. The message is the last error.
    """
    filename = image_path.name
    # Callers index the required keys directly, so a reply without them is
    # retried here instead of failing later with a KeyError.
    required_keys = list(schema.get("required") or [])
    attempts = MAX_RETRIES + 1
    last_error = ""
    for attempt in range(attempts):
        is_last_attempt = attempt == attempts - 1
        try:
            step_started = time.monotonic()
            logging.info(f"{filename} | step={step_name} | attempt={attempt + 1} | request_start")
            raw_response = call_ollama(
                image_path=image_path,
                user_prompt=user_prompt,
                schema=schema,
                # Retries drop to greedy decoding to make a valid reply likelier.
                temperature=0.0 if attempt > 0 else TEMPERATURE,
            )
            elapsed = time.monotonic() - step_started
            logging.info(f"{filename} | step={step_name} | attempt={attempt + 1} | request_done | sec={elapsed:.2f}")
            parsed, parse_error = safe_json_loads(raw_response)
            if parsed is None:
                last_error = f"{step_name}: parse error: {parse_error}"
            elif not isinstance(parsed, dict):
                last_error = f"{step_name}: expected a JSON object, got {type(parsed).__name__}"
            else:
                missing = [key for key in required_keys if key not in parsed]
                if not missing:
                    return parsed, raw_response
                last_error = f"{step_name}: missing required keys: {missing}"
            logging.warning(f"{filename} | attempt={attempt + 1} | {last_error}")
            logging.info(f"{filename} | raw response preview: {raw_response[:500]}")
        except ResponseError as e:
            last_error = f"{step_name}: {e}"
            logging.exception(f"{filename} | attempt={attempt + 1} failed")
            # Back off when the server reports 503, but never after the final
            # attempt: that would only delay the raise below.
            if "503" in str(e) and not is_last_attempt:
                time.sleep(RETRY_SLEEP_SEC)
        except Exception as e:
            last_error = f"{step_name}: {e}"
            logging.exception(f"{filename} | attempt={attempt + 1} failed")
            if not is_last_attempt:
                time.sleep(2)
    raise RuntimeError(last_error or f"{step_name}: model failed after retries")
def build_comment_from_label_and_summary(label: str, summary_text: str, rationale: str) -> str:
    """Turn the final judge's rationale into a short one-sentence comment.

    Args:
        label: Final label. ``"good"`` and ``"bad"`` pick specific fallback
            sentences; any other value picks the "mixed" fallback.
        summary_text: Currently unused; kept for interface compatibility.
        rationale: Free-text rationale. Surrounding whitespace and trailing
            periods are removed before formatting.

    Returns:
        A fallback sentence when the rationale is empty. A rationale longer
        than 170 characters is cut to 167 characters and ends in ``"..."``.
        Otherwise the rationale ends with exactly one sentence-ending mark:
        an existing ``!``/``?`` is kept, else ``"."`` is appended.
    """
    base = (rationale or "").strip().rstrip(".")
    if not base:
        if label == "good":
            return "The poster shows strong overall design control."
        if label == "bad":
            return "The poster shows weak overall design control."
        return "The poster shows mixed design quality."
    if len(base) > 170:
        # The ellipsis already ends the sentence. Appending "." as well (or
        # keeping a period at the cut point) would produce "....".
        cut = base[:167].rstrip().rstrip(".").rstrip()
        return cut + "..."
    if base.endswith(("!", "?")):
        return base
    return base + "."
def _coerce_bool(value: Any) -> bool:
    """Interpret a model-provided yes/no flag.

    Plain ``bool(value)`` treats any non-empty string, including ``"false"``,
    as True. Strings are therefore matched against explicit truthy spellings,
    and non-strings fall back to ``bool(value)``.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "y", "1")
    return bool(value)


def get_llm_diagnostics(image_path: Path, img_feat: Dict[str, Any]) -> Dict[str, Any]:
    """Score one poster image through the multi-step LLM pipeline.

    Steps:
      1. Ask whether the image has meaningful content (``presence`` step). If
         it does not, return a fixed low-score result immediately.
      2. Ask for each metric in ``METRIC_KEYS`` separately, and clamp each
         answer to 1..5 with ``clamp_int_1_5``.
      3. Reconcile the raw metrics with ``img_feat`` and build a text summary.
      4. Run the final judge to get a label, a confidence and a rationale.

    Args:
        image_path: Poster image to evaluate.
        img_feat: Image features passed to ``reconcile_metrics``; their
            structure is defined there.

    Returns:
        A dict with these keys:

        * ``content_present``
        * the metric scores
        * ``confidence``
        * ``comment``
        * ``_final_label_prior``
        * ``_raw_response``: a JSON string holding the raw model traces.

    Raises:
        RuntimeError: raised by ``ask_json`` when a step fails after retries,
            or here when the assembled result fails ``validate_scores``.
            Errors from ``reconcile_metrics``/``run_final_judge`` propagate
            unchanged.
    """
    filename = image_path.name
    raw_trace: Dict[str, Any] = {}
    logging.info(f"{filename} | step=presence | start")
    presence_data, raw = ask_json(
        image_path=image_path,
        user_prompt=build_presence_prompt(filename),
        schema=PRESENCE_SCHEMA,
        step_name="presence",
    )
    logging.info(f"{filename} | step=presence | done")
    raw_trace["presence"] = raw
    content_present = _coerce_bool(presence_data["content_present"])
    if not content_present:
        # NOTE(review): this early result is not passed through
        # validate_scores, and its _raw_response is the bare trace dict
        # rather than the {"raw_trace": ...} wrapper used further down.
        result = {
            "content_present": False,
            "hierarchy_strength": 1,
            "focal_clarity": 1,
            "typography_control": 1,
            "font_conflict": 1,
            "palette_control": 1,
            "color_conflict": 1,
            "visual_clutter": 1,
            "template_genericity": 1,
            "originality": 1,
            "ai_generated_score": 1,
            "technical_quality": 3,
            "visual_impact": 1,
            "confidence": 0.95,
            "comment": "No meaningful poster content detected.",
            "_final_label_prior": "bad",
            "_raw_response": json.dumps(raw_trace, ensure_ascii=False),
        }
        return result
    raw_metrics: Dict[str, int] = {}
    # One request per metric, so each answer is judged in isolation.
    for metric_name in METRIC_KEYS:
        metric_data, raw = ask_json(
            image_path=image_path,
            user_prompt=build_metric_prompt(filename, metric_name),
            schema=metric_schema(metric_name),
            step_name=metric_name,
        )
        raw_metrics[metric_name] = clamp_int_1_5(metric_data[metric_name])
        raw_trace[metric_name] = raw
    reconciled_metrics, warnings = reconcile_metrics(raw_metrics, img_feat)
    summary_text = build_structured_summary(reconciled_metrics, warnings)
    final_judge = run_final_judge(
        client=CLIENT,
        model_name=MODEL_NAME,
        top_p=TOP_P,
        num_ctx=NUM_CTX,
        image_path=image_path,
        summary_text=summary_text,
        metrics=reconciled_metrics,
        system_prompt=SYSTEM_PROMPT,
    )
    label = str(final_judge["label"])
    confidence = float(final_judge["confidence"])
    rationale = str(final_judge["rationale"]).strip()
    result = {
        "content_present": True,
        **reconciled_metrics,
        "confidence": confidence,
        "comment": build_comment_from_label_and_summary(label, summary_text, rationale),
        "_final_label_prior": label,
        "_raw_response": json.dumps(
            {
                "raw_trace": raw_trace,
                "raw_metrics": raw_metrics,
                "reconciled_metrics": reconciled_metrics,
                "warnings": warnings,
                "summary_text": summary_text,
                "final_judge": final_judge,
            },
            ensure_ascii=False,
        ),
    }
    is_valid, validation_msg = validate_scores(result)
    if not is_valid:
        raise RuntimeError(f"Validation error: {validation_msg} | result={result}")
    return result