# Poster_Analyzer/app/src/llm_diagnostics.py
# DatsuNTOYOTA's picture
# init app
# ec4da21 verified
from __future__ import annotations
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from ollama import Client, ResponseError
from config import (
MAX_RETRIES,
MODEL_NAME,
NUM_CTX,
OLLAMA_HOST,
RETRY_SLEEP_SEC,
TEMPERATURE,
TOP_P,
)
from final_judge import run_final_judge
from prompt import (
SYSTEM_PROMPT,
build_metric_prompt,
build_presence_prompt,
)
from reconciler import build_structured_summary, reconcile_metrics
from schema import (
METRIC_KEYS,
PRESENCE_SCHEMA,
metric_schema,
validate_scores,
)
# Drop every proxy variable (both upper- and lower-case spellings) from the
# process environment and mark localhost / 127.0.0.1 as proxy-exempt.
# NOTE(review): presumably this keeps requests to a locally running Ollama
# server from being routed through a system proxy — confirm against
# OLLAMA_HOST in config.
for key in (
    "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
    "http_proxy", "https_proxy", "all_proxy",
):
    os.environ.pop(key, None)
os.environ.update(
    {
        "NO_PROXY": "localhost,127.0.0.1",
        "no_proxy": "localhost,127.0.0.1",
    }
)

# Shared Ollama client used by every call in this module (120 s timeout).
CLIENT = Client(host=OLLAMA_HOST, timeout=120.0)
def safe_json_loads(text: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Parse a model response into a JSON object, tolerating surrounding noise.

    The whole (stripped) text is parsed first. If that fails, or yields
    something other than a JSON object, the span from the first ``{`` to the
    last ``}`` is parsed instead.

    Args:
        text: Raw response text; ``None`` or an empty string is treated as an
            empty response.

    Returns:
        ``(data, None)`` on success, where ``data`` is always a ``dict``;
        ``(None, message)`` on failure, where ``message`` describes the problem.
    """
    raw = (text or "").strip()
    if not raw:
        return None, "Empty response"

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        parsed = None
    # Only a JSON object satisfies the declared return type; callers index the
    # result by key, so a bare list/number/string/null must not be returned.
    if isinstance(parsed, dict):
        return parsed, None

    start = raw.find("{")
    end = raw.rfind("}")
    if start != -1 and end > start:
        fragment = raw[start:end + 1]
        try:
            # The fragment starts with "{", so a successful parse is a dict.
            return json.loads(fragment), None
        except json.JSONDecodeError as e:
            return None, f"JSON decode error after extraction: {e}"
    return None, "No JSON object found in response"
def clamp_int_1_5(x: Any) -> int:
    """Coerce *x* to an integer score in the closed range 1..5.

    The value is converted with ``float`` and rounded with the built-in
    ``round``. Anything that cannot be converted yields the neutral score 3.
    """
    try:
        score = int(round(float(x)))
    except Exception:
        # Unparseable input falls back to the midpoint of the scale.
        return 3
    if score < 1:
        return 1
    if score > 5:
        return 5
    return score
def check_ollama_available() -> None:
    """Verify that the Ollama client answers a ``list()`` request.

    Raises:
        RuntimeError: If the response does not contain a ``"models"`` entry.
    """
    listing = CLIENT.list()
    if "models" in listing:
        return
    raise RuntimeError("Ollama list() did not return models")
def check_model_present(model_name: str) -> None:
    """Verify that *model_name* appears among the models Ollama reports.

    Each entry's ``"model"`` field is used, falling back to ``"name"``.
    A model counts as present when *model_name* equals the reported name or
    is contained in it.

    Raises:
        RuntimeError: If no reported model matches, with a hint to pull it.
    """
    listing = CLIENT.list()
    reported = [
        entry.get("model") or entry.get("name")
        for entry in listing.get("models", [])
    ]
    for candidate in reported:
        if not candidate:
            # Entries without a usable name are ignored.
            continue
        if model_name == candidate or model_name in candidate:
            return
    raise RuntimeError(f"Model '{model_name}' not found. Run: ollama pull {model_name}")
def smoke_test_text() -> None:
    """Send a minimal text-only chat request and log the model's reply.

    The reply is only logged; it is not compared against the expected "OK".
    """
    probe_messages = [{"role": "user", "content": "Reply with exactly: OK"}]
    probe_options = {"temperature": 0.0, "top_p": 1.0, "num_ctx": 512}
    response = CLIENT.chat(
        model=MODEL_NAME,
        messages=probe_messages,
        options=probe_options,
    )
    content = response["message"]["content"].strip()
    logging.info(f"Smoke test response: {content}")
def call_ollama(
    image_path: Path,
    user_prompt: str,
    schema: Dict[str, Any],
    temperature: float = TEMPERATURE,
) -> str:
    """Send one image + prompt to the model and return the raw reply text.

    Args:
        image_path: Image attached to the user message (passed as a string path).
        user_prompt: Text of the user message.
        schema: Passed to the client as the ``format`` argument.
        temperature: Sampling temperature for this call.

    Returns:
        The ``content`` of the response message, unparsed.

    Raises:
        ValueError: If the response lacks ``message`` or ``message.content``.
    """
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {
        "role": "user",
        "content": user_prompt,
        "images": [str(image_path)],
    }
    sampling = {
        "temperature": temperature,
        "top_p": TOP_P,
        "num_ctx": NUM_CTX,
    }
    response = CLIENT.chat(
        model=MODEL_NAME,
        messages=[system_message, user_message],
        format=schema,
        options=sampling,
    )
    has_content = "message" in response and "content" in response["message"]
    if not has_content:
        raise ValueError(f"Unexpected Ollama SDK response format: {response}")
    return response["message"]["content"]
def ask_json(
    image_path: Path,
    user_prompt: str,
    schema: Dict[str, Any],
    step_name: str,
) -> Tuple[Dict[str, Any], str]:
    """Query the model for one step and return its reply parsed as JSON.

    Up to ``MAX_RETRIES + 1`` attempts are made. The first attempt uses the
    configured ``TEMPERATURE``; every retry uses temperature 0.0. A reply that
    cannot be parsed triggers an immediate retry. A ``ResponseError`` with
    status 503 waits ``RETRY_SLEEP_SEC`` before the next attempt, and any other
    exception waits 2 seconds. No wait happens after the final attempt.

    Args:
        image_path: Image to analyse; its file name is used in log lines.
        user_prompt: Prompt text for this step.
        schema: Schema forwarded to ``call_ollama``.
        step_name: Label used in logs and error messages.

    Returns:
        ``(parsed, raw_response)``: the parsed JSON object and the raw text.

    Raises:
        RuntimeError: When every attempt failed. The message carries the last
            error, and the last caught exception (if any) is chained as the cause.
    """
    filename = image_path.name
    last_error = ""
    last_exc: Optional[BaseException] = None
    total_attempts = MAX_RETRIES + 1

    for attempt in range(total_attempts):
        # Sleeping after the last attempt would only delay the final error.
        will_retry = attempt < total_attempts - 1
        try:
            step_started = time.time()
            logging.info(f"{filename} | step={step_name} | attempt={attempt + 1} | request_start")
            raw_response = call_ollama(
                image_path=image_path,
                user_prompt=user_prompt,
                schema=schema,
                temperature=0.0 if attempt > 0 else TEMPERATURE,
            )
            elapsed = time.time() - step_started
            logging.info(f"{filename} | step={step_name} | attempt={attempt + 1} | request_done | sec={elapsed:.2f}")
            parsed, parse_error = safe_json_loads(raw_response)
            if parsed is None:
                last_error = f"{step_name}: parse error: {parse_error}"
                # A parse failure is not an exception, so there is no cause to chain.
                last_exc = None
                logging.warning(f"{filename} | attempt={attempt + 1} | {last_error}")
                logging.info(f"{filename} | raw response preview: {raw_response[:500]}")
                continue
            return parsed, raw_response
        except ResponseError as e:
            last_error = f"{step_name}: {e}"
            last_exc = e
            logging.exception(f"{filename} | attempt={attempt + 1} failed")
            # Prefer the structured status code; keep the substring check as a
            # fallback for errors that only mention 503 in their message.
            is_unavailable = getattr(e, "status_code", None) == 503 or "503" in str(e)
            if will_retry and is_unavailable:
                time.sleep(RETRY_SLEEP_SEC)
        except Exception as e:
            last_error = f"{step_name}: {e}"
            last_exc = e
            logging.exception(f"{filename} | attempt={attempt + 1} failed")
            if will_retry:
                time.sleep(2)

    raise RuntimeError(last_error or f"{step_name}: model failed after retries") from last_exc
def build_comment_from_label_and_summary(label: str, summary_text: str, rationale: str) -> str:
    """Build the short human-readable comment for a poster.

    The rationale is used when it has content; otherwise a canned sentence is
    chosen from the label. A rationale longer than 170 characters is cut to
    167 characters plus ``"..."``, so the truncated comment is at most 170
    characters and ends with the ellipsis only.

    Args:
        label: Final label; ``"good"`` and ``"bad"`` select their own fallback
            sentence, and any other value selects the "mixed" one.
        summary_text: Accepted for interface compatibility; not used.
        rationale: Free-text justification; surrounding whitespace and
            trailing periods are stripped.

    Returns:
        A single sentence ending in ``"."``, or in ``"..."`` when truncated.
    """
    base = rationale.strip().rstrip(".")
    if not base:
        if label == "good":
            return "The poster shows strong overall design control."
        if label == "bad":
            return "The poster shows weak overall design control."
        return "The poster shows mixed design quality."
    if len(base) > 170:
        # Return the ellipsis as the terminator: appending "." here produced
        # "...." and pushed the comment past the 170-character cap.
        return base[:167].rstrip().rstrip(".") + "..."
    return base + "."
def get_llm_diagnostics(image_path: Path, img_feat: Dict[str, Any]) -> Dict[str, Any]:
    """Run the multi-step LLM review of one poster image and return its scores.

    Steps, in order:
      1. Ask whether the image has content ("presence"). If not, return a
         fixed low-score result immediately.
      2. Ask the model for each metric in ``METRIC_KEYS`` and clamp it to 1..5.
      3. Reconcile the raw metrics with ``img_feat`` and build a summary.
      4. Run the final judge to obtain a label, confidence and rationale.
      5. Validate the assembled result.

    Args:
        image_path: Poster image to analyse.
        img_feat: Image features forwarded to ``reconcile_metrics``.

    Returns:
        A dict with ``content_present``, the metric scores, ``confidence``,
        ``comment``, ``_final_label_prior`` and ``_raw_response`` (a JSON
        string with the raw model replies and intermediate data).

    Raises:
        RuntimeError: If ``validate_scores`` rejects the assembled result.
            Errors raised by the individual model calls propagate unchanged.
    """
    filename = image_path.name
    raw_trace: Dict[str, Any] = {}

    # Step 1: presence gate.
    logging.info(f"{filename} | step=presence | start")
    presence_data, presence_raw = ask_json(
        image_path=image_path,
        user_prompt=build_presence_prompt(filename),
        schema=PRESENCE_SCHEMA,
        step_name="presence",
    )
    logging.info(f"{filename} | step=presence | done")
    raw_trace["presence"] = presence_raw

    if not presence_data["content_present"]:
        # Nothing to score: return the fixed placeholder result.
        return {
            "content_present": False,
            "hierarchy_strength": 1,
            "focal_clarity": 1,
            "typography_control": 1,
            "font_conflict": 1,
            "palette_control": 1,
            "color_conflict": 1,
            "visual_clutter": 1,
            "template_genericity": 1,
            "originality": 1,
            "ai_generated_score": 1,
            "technical_quality": 3,
            "visual_impact": 1,
            "confidence": 0.95,
            "comment": "No meaningful poster content detected.",
            "_final_label_prior": "bad",
            "_raw_response": json.dumps(raw_trace, ensure_ascii=False),
        }

    # Step 2: one model call per metric.
    raw_metrics: Dict[str, int] = {}
    for metric_name in METRIC_KEYS:
        metric_reply, metric_raw = ask_json(
            image_path=image_path,
            user_prompt=build_metric_prompt(filename, metric_name),
            schema=metric_schema(metric_name),
            step_name=metric_name,
        )
        raw_metrics[metric_name] = clamp_int_1_5(metric_reply[metric_name])
        raw_trace[metric_name] = metric_raw

    # Step 3: reconcile against image features and summarise.
    reconciled_metrics, warnings = reconcile_metrics(raw_metrics, img_feat)
    summary_text = build_structured_summary(reconciled_metrics, warnings)

    # Step 4: final verdict.
    final_judge = run_final_judge(
        client=CLIENT,
        model_name=MODEL_NAME,
        top_p=TOP_P,
        num_ctx=NUM_CTX,
        image_path=image_path,
        summary_text=summary_text,
        metrics=reconciled_metrics,
        system_prompt=SYSTEM_PROMPT,
    )
    label = str(final_judge["label"])
    confidence = float(final_judge["confidence"])
    rationale = str(final_judge["rationale"]).strip()

    # Assemble the result in the same key order as a single dict literal would.
    result: Dict[str, Any] = {"content_present": True}
    result.update(reconciled_metrics)
    result["confidence"] = confidence
    result["comment"] = build_comment_from_label_and_summary(label, summary_text, rationale)
    result["_final_label_prior"] = label
    debug_payload = {
        "raw_trace": raw_trace,
        "raw_metrics": raw_metrics,
        "reconciled_metrics": reconciled_metrics,
        "warnings": warnings,
        "summary_text": summary_text,
        "final_judge": final_judge,
    }
    result["_raw_response"] = json.dumps(debug_payload, ensure_ascii=False)

    # Step 5: validation.
    is_valid, validation_msg = validate_scores(result)
    if is_valid:
        return result
    raise RuntimeError(f"Validation error: {validation_msg} | result={result}")