# Source file: g_solution/app/services/classifier.py
# Hugging Face file-viewer residue kept for provenance:
#   uploader: minhvtt; commit message: "Update app/services/classifier.py";
#   commit 43794e1 (verified); size 32 kB.
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any
# Local model directory: "models" under parents[2] of this file's resolved path.
MODELS_DIR = Path(__file__).resolve().parents[2] / "models"
# Config and weights read by _load_nsfw_runtime for the NSFW image classifier.
NSFW_CONFIG_PATH = MODELS_DIR / "config.json"
NSFW_WEIGHTS_PATH = MODELS_DIR / "model.safetensors"
# An image is flagged as sensitive when its NSFW score is >= this value.
NSFW_THRESHOLD = 0.75
def _parse_csv_env(raw_value: str, default_csv: str) -> list[str]:
value = (raw_value or "").strip() or default_csv
return [x.strip() for x in value.split(",") if x.strip()]
# Hugging Face model id used for zero-shot image/text matching (overridable via env).
GAME_ZERO_SHOT_MODEL = os.getenv("GAME_ZERO_SHOT_MODEL", "google/siglip-so400m-patch14-384")
GAME_THRESHOLD = float(os.getenv("GAME_THRESHOLD", "0.65"))  # game cutoff; defaults to 0.65 unless overridden via env
# Game-focused text labels for SigLIP image/text matching.
# Several game types are listed to cover diverse screenshots; ambiguous cases are
# handed to the Qwen decision gate in _classify_game_with_zero_shot.
# NOTE: list order matters. GAME_ZERO_SHOT_BROWSER_GAME_LABEL_INDICES hard-codes
# positions 5-11, which are the "Browser / online games" entries below.
GAME_LABELS = [
    # Roblox specific (indices 0-4)
    "Roblox game with blocky 3D characters and Roblox UI",
    "Roblox game lobby with avatar and robux currency display",
    "Roblox game world with colorful blocky environment and players",
    "Roblox obby or obstacle course game with platforms",
    "Roblox roleplay game with characters and chat bubbles",
    # Browser / online games, gamevui.vn style (indices 5-11)
    "a Vietnamese browser game website with interactive play controls and cartoon characters",
    "an online flash or HTML5 game embedded in a webpage with score or play controls",
    "a casual browser game with score counter and timer",
    "a mobile-style game running in a browser with touch controls",
    "a colorful 2D browser game with animated sprites and game UI",
    "an online puzzle or matching game with colorful tiles",
    "a drawing or coloring game for kids in a browser",
    # General game UI signals
    "a 3D game screen with health bar lives counter and score display",
    "a video game with minimap inventory and player stats HUD",
    "a video game with a big yellow PLAY button or START button on a colorful background",
    "a video game lobby with 3D avatars, player list, and a START GAME button",
    "a game over or victory screen with score results and rank",
    "a video game character selection screen with interactive buttons and stats",
    "a role playing game RPG with character stats, inventory, and quest log",
    "a 3D action game with weapons crosshair and enemies",
    "a strategy or tower defense game with map and units",
    "a horror or dark themed game with spooky environment",
    "a sports or racing game with speed meter and track",
    "a fighting game with two characters and health bars",
    "a game loading screen with progress bar and game logo",
]
# Contrast labels: when one of these is the top-scoring label with a score of at
# least GAME_ZERO_SHOT_NOT_GAME_THRESHOLD, the zero-shot path returns not_game.
# They are also the first entries of NOT_GAME_LABELS.
CONTRAST_NOT_GAME_LABELS = [
    # Browser shell
    "a web browser window with many tabs at the top",
    "a chrome browser address bar showing a website URL",
    "a browser window with extension icons and profile picture",
    # Web verification / CAPTCHA
    "a website verification page with a progress bar and bot check",
    "Cloudflare verify you are human page with loading animation",
    "a white browser page with text saying verify you are not a bot",
    "a security check page on a website with a loading bar",
    # Shared screen / Meet
    "a screen being shared via Google Meet with a blue stop sharing button at the bottom",
    "a full screen presentation with a meet.google.com sharing notification bar",
    "a computer desktop being recorded or shared in a video meeting",
    # Text-heavy educational/article pages (common false positive against browser-game labels)
    "a Vietnamese educational article webpage with long text paragraphs and lesson navigation sidebar",
    "an online lesson page with school grade menu and literary analysis text content",
    "a reading article website with dense paragraphs and no gameplay controls or score HUD",
]
# Full non-game label set: the contrast labels first, then further non-game
# screen categories that compete with GAME_LABELS in the zero-shot softmax.
NOT_GAME_LABELS = [
    *CONTRAST_NOT_GAME_LABELS,
    # Photo / video / 3D editing (added to stop 3D-editing screens from matching game labels)
    "Adobe Photoshop with layers panel and image canvas",
    "photo editing software with color palette and tool panel",
    "video editing timeline with clips and playhead",
    "3D modeling software with mesh wireframe and viewport grid",
    "Blender or Maya 3D modeling interface with object properties",
    "Canva or design tool with drag and drop elements",
    "image editing app with filters brightness and contrast sliders",
    "a graphic design application with artboard and shapes",
    "Microsoft PowerPoint editor with ribbon menu and slide thumbnails on the left",
    "a presentation slide editor with many formatting tools and icons",
    # Productivity
    "Microsoft Excel spreadsheet with rows columns and cell data",
    "Google Sheets or table with numerical data and formulas",
    "Microsoft Word or Google Docs document with text paragraphs",
    "PowerPoint or presentation slides with bullet points",
    "a PDF viewer showing a document or report",
    "a calendar or scheduling application with events",
    "a project management board like Trello or Notion",
    "a form or survey with input fields and checkboxes",
    # Code / dev tools
    "a code editor like VS Code with syntax highlighted programming code",
    "a terminal or command line interface with text commands",
    "a browser developer tools panel with HTML and CSS",
    "a database management interface with tables and queries",
    "a SQL editor with query results and database schema panels",
    "a database admin tool like phpMyAdmin or Adminer",
    "a database client with table list query editor and result grid",
    "a database dashboard with records indexes and schema views",
    "a dashboard with charts graphs and analytics data",
    # Video conferencing (added for Google Meet / Zoom screens)
    "a Google Meet video call with a grid of participant faces",
    "a Zoom or Microsoft Teams meeting with gallery view of people",
    # Chat / media / gallery
    "a chat application screen with message bubbles and shared images",
    "a browser page showing multiple 3D render images in a gallery",
    "an image gallery or collage page with multiple 3D renders or artwork",
    "a browser page showing product photos, renders, or artwork thumbnails",
    "a Google search results page with search bar and blue links",
    "a web search page showing search results and snippets",
    "a browser page with a Google search box and results list",
    # Certificates / awards
    "a certificate or diploma page with ornate border and completion text",
    "an award certificate for a student completion or achievement",
    "a certificate cover page with decorative border and title text",
    "an online certificate template with signature and seal",
]
# Combined label list fed to the zero-shot model: game labels first, then non-game labels.
GAME_ZERO_SHOT_LABELS = GAME_LABELS + NOT_GAME_LABELS
# Index ranges into GAME_ZERO_SHOT_LABELS for each group.
GAME_ZERO_SHOT_GAME_LABEL_INDICES = tuple(range(len(GAME_LABELS)))
GAME_ZERO_SHOT_NOT_GAME_LABEL_INDICES = tuple(range(len(GAME_LABELS), len(GAME_ZERO_SHOT_LABELS)))
# Indices of the contrast labels, which sit directly after the game labels.
GAME_ZERO_SHOT_CLEAR_NOT_GAME_LABEL_INDICES = tuple(
    range(len(GAME_LABELS), len(GAME_LABELS) + len(CONTRAST_NOT_GAME_LABELS))
)
# NOTE(review): hard-coded positions of the "Browser / online games" entries in
# GAME_LABELS; these must be updated by hand whenever GAME_LABELS is reordered.
GAME_ZERO_SHOT_BROWSER_GAME_LABEL_INDICES = {5, 6, 7, 8, 9, 10, 11}
GAME_ZERO_SHOT_GAME_THRESHOLD = GAME_THRESHOLD  # kept identical to GAME_THRESHOLD (default 0.65)
GAME_ZERO_SHOT_NOT_GAME_THRESHOLD = float(os.getenv("GAME_ZERO_SHOT_NOT_GAME_THRESHOLD", "0.30"))  # minimum top score for a contrast label to decide not_game
GAME_ZERO_SHOT_AMBIGUOUS_THRESHOLD = float(os.getenv("GAME_ZERO_SHOT_AMBIGUOUS_THRESHOLD", "0.20"))  # top-two game scores closer than this count as ambiguous
# Qwen text model used as the final decision gate for ambiguous zero-shot results.
QWEN_MODEL = os.getenv("QWEN_MODEL", "Qwen/Qwen2-0.5B-Instruct")
# Only the exact (case-insensitive) value "true" enables the gate.
QWEN_ENABLE = os.getenv("QWEN_ENABLE", "true").lower() == "true"
# Minimum Qwen confidence required to accept a "game" verdict (see
# _should_accept_qwen_game_decision).
QWEN_GAME_DECISION_THRESHOLD = float(os.getenv("QWEN_GAME_DECISION_THRESHOLD", "0.65"))
# Lazily initialised runtimes, one per model. Each *_runtime stays None until its
# loader succeeds; each *_error records the first failure so the loader does not
# retry on later calls.
_nsfw_runtime: dict[str, Any] | None = None
_nsfw_error: str | None = None
_game_zero_shot_runtime: dict[str, Any] | None = None
_game_zero_shot_error: str | None = None
_qwen_runtime: dict[str, Any] | None = None
_qwen_error: str | None = None
def _zero_shot_label_group(index: int) -> str:
    """Return ``"game"`` for an index inside GAME_LABELS, otherwise ``"not_game"``."""
    if index >= len(GAME_LABELS):
        return "not_game"
    return "game"
def _zero_shot_debug_label(index: int, description: str) -> str:
    """Build a short debug tag for a combined-label index.

    The tag is ``game_<i>`` or ``not_game_<i>`` (``i`` counted within the
    label's own group), followed by ``:`` and the first 36 characters of
    ``description`` with trailing whitespace removed.
    """
    game_count = len(GAME_LABELS)
    in_game_group = index < game_count
    tag = f"game_{index}" if in_game_group else f"not_game_{index - game_count}"
    return f"{tag}:{description[:36].rstrip()}"
def _should_accept_qwen_game_decision(verdict: str, confidence: float) -> bool:
    """Return True when Qwen's verdict is "game" with sufficient confidence.

    The Qwen prompt asks for a confidence on a 0-100 scale, while
    QWEN_GAME_DECISION_THRESHOLD is expressed on a 0-1 scale. Values above 1.0
    are therefore treated as percentages and scaled down before comparison;
    otherwise a reported confidence of e.g. 5 (percent) would pass a 0.65
    threshold.
    """
    scaled = confidence / 100.0 if confidence > 1.0 else confidence
    return verdict == "game" and scaled >= QWEN_GAME_DECISION_THRESHOLD
def _parse_qwen_decision_response(response: str) -> dict[str, Any]:
cleaned = response.strip()
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
try:
payload = json.loads(cleaned[start : end + 1])
verdict = str(payload.get("verdict", "")).strip().lower()
confidence = float(payload.get("confidence", 0.0))
if verdict in {"game", "not_game"}:
return {"verdict": verdict, "confidence": confidence}
except Exception:
pass
upper = cleaned.upper()
if "NOT_GAME" in upper:
return {"verdict": "not_game", "confidence": 0.0}
if "GAME" in upper:
return {"verdict": "game", "confidence": 0.0}
return {"verdict": "uncertain", "confidence": 0.0}
def _is_gallery_like_not_game(description: str) -> bool:
text = description.lower()
return any(
token in text
for token in (
"gallery",
"collage",
"render",
"artwork",
"product photos",
"shared images",
"image gallery",
)
)
def _is_text_heavy_not_game(description: str) -> bool:
text = description.lower()
return any(
token in text
for token in (
"educational",
"lesson",
"article",
"literary analysis",
"dense paragraphs",
"reading article",
"text paragraphs",
)
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
# Substrings that classify_sensitive_content looks for in the lower-cased file
# names when the NSFW model gives no result.
SENSITIVE_KEYWORDS = {
    "porn",
    "sex",
    "xxx",
    "nsfw",
    "adult",
    "nude",
    "erotic",
}
def warmup_game_classifiers() -> dict[str, str]:
    """Eagerly load the game-detection models and report their readiness.

    Returns a mapping with a ``"vision"`` entry for the zero-shot model and,
    only when QWEN_ENABLE is set, a ``"decision"`` entry for the Qwen gate.
    Each value is ``ready:<model>`` or ``unavailable:<error>``.
    """
    report: dict[str, str] = {}

    if _load_game_zero_shot_runtime() is None:
        report["vision"] = f"unavailable:{_game_zero_shot_error or 'unknown'}"
    else:
        report["vision"] = f"ready:{GAME_ZERO_SHOT_MODEL}"

    if QWEN_ENABLE:
        if _load_qwen_runtime() is None:
            report["decision"] = f"unavailable:{_qwen_error or 'unknown'}"
        else:
            report["decision"] = f"ready:{QWEN_MODEL}"

    logger.info("game-classifier warmup status=%s", report)
    return report
def classify_screenshot(file_path: str, filename: str, suspected_game: bool) -> tuple[bool, float, str]:
    """Classify a screenshot and flatten the result to ``(flagged, confidence, reason)``.

    A ``game`` verdict is flagged with its own confidence and reason. An
    ``uncertain`` verdict is also flagged, with the confidence raised to at
    least 0.51 and the reason replaced by ``"uncertain-review"``. Any other
    verdict is not flagged.
    """
    outcome = classify_game_with_ocr_llm(file_path, filename, suspected_game)
    verdict = outcome["verdict"]
    confidence = float(outcome["confidence"])
    reason = str(outcome["reason"])

    if verdict == "uncertain":
        return True, max(confidence, 0.51), "uncertain-review"
    return verdict == "game", confidence, reason
def classify_game_with_ocr_llm(file_path: str, filename: str, suspected_game: bool) -> dict[str, Any]:
    """Run game detection and return a result dict.

    The zero-shot vision path is used when it yields a result. Otherwise a
    heuristic answers from ``suspected_game`` alone: ``uncertain`` (0.55) when
    set, ``not_game`` (0.2) when not.

    The returned dict has the keys ``verdict``, ``confidence``, ``reason``,
    ``ocr_text``, ``urls`` and ``source``. ``ocr_text`` and ``urls`` are always
    empty in this function.
    """
    ocr_text = ""
    urls: list[str] = []
    logger.info(
        "game-detect start file=%s suspected_game=%s",
        filename,
        suspected_game,
    )

    vision = _classify_game_with_zero_shot(file_path, suspected_game)
    if vision is None:
        # No zero-shot result (runtime unavailable or inference failed): use the heuristic.
        if suspected_game:
            logger.info(
                "game-detect fallback verdict=uncertain reason=signal-without-vision-runtime"
            )
            verdict, confidence, reason = "uncertain", 0.55, "signal-without-vision-runtime"
        else:
            logger.info("game-detect fallback verdict=not_game reason=no-game-evidence")
            verdict, confidence, reason = "not_game", 0.2, "no-game-evidence"
        return {
            "verdict": verdict,
            "confidence": confidence,
            "reason": reason,
            "ocr_text": ocr_text,
            "urls": urls,
            "source": "heuristic",
        }

    logger.info(
        "game-detect zero-shot verdict=%s confidence=%.2f reason=%s source=%s",
        vision["verdict"],
        float(vision["confidence"]),
        vision["reason"],
        vision.get("source", "zero-shot"),
    )
    return {
        "verdict": str(vision.get("verdict", "not_game")),
        "confidence": float(vision.get("confidence", 0.3)),
        "reason": str(vision.get("reason", "no_game_evidence")),
        "ocr_text": ocr_text,
        "urls": urls,
        "source": str(vision.get("source", "zero-shot")),
    }
def filter_cacheable_game_urls(urls: list[str]) -> list[str]:
    """Return the cacheable subset of ``urls``, which is currently always empty.

    OCR-based URL extraction has been removed, so nothing is ever cacheable.
    """
    cacheable: list[str] = []
    return cacheable
def _load_game_zero_shot_runtime() -> dict[str, Any] | None:
    """Return the cached zero-shot runtime, loading it on first use.

    The runtime dict holds ``model``, ``processor``, ``torch``, ``Image`` and
    ``device``. A failed load is remembered in ``_game_zero_shot_error`` and is
    not retried; ``None`` is returned from then on.
    """
    global _game_zero_shot_runtime, _game_zero_shot_error

    # Either already loaded (return it) or already failed (runtime is still None).
    if _game_zero_shot_runtime is not None or _game_zero_shot_error is not None:
        return _game_zero_shot_runtime

    try:
        import torch
        from PIL import Image
        from transformers import AutoModel, AutoProcessor

        # Inference is pinned to CPU with full float32 precision.
        target_device = "cpu"
        precision = torch.float32

        # NOTE(review): trust_remote_code=True allows the model repository named by
        # GAME_ZERO_SHOT_MODEL (env-configurable) to run its own Python code.
        vision_model = AutoModel.from_pretrained(
            GAME_ZERO_SHOT_MODEL,
            trust_remote_code=True,
            dtype=precision,
            low_cpu_mem_usage=True,
        )
        vision_model = vision_model.to(target_device)
        vision_model.eval()

        vision_processor = AutoProcessor.from_pretrained(
            GAME_ZERO_SHOT_MODEL,
            trust_remote_code=True
        )

        _game_zero_shot_runtime = {
            "model": vision_model,
            "processor": vision_processor,
            "torch": torch,
            "Image": Image,
            "device": target_device,
        }
        logger.info("game zero-shot runtime initialized model=%s device=%s", GAME_ZERO_SHOT_MODEL, target_device)
        return _game_zero_shot_runtime
    except Exception as exc:
        _game_zero_shot_error = str(exc)
        logger.warning("game zero-shot runtime init failed model=%s err=%s", GAME_ZERO_SHOT_MODEL, _game_zero_shot_error)
        return None
def _load_qwen_runtime() -> dict[str, Any] | None:
    """Return the cached Qwen decision-gate runtime, loading it on first use.

    Returns ``None`` when QWEN_ENABLE is off or when loading has failed; a
    failure is stored in ``_qwen_error`` and not retried. On success the dict
    holds ``model``, ``tokenizer``, ``torch`` and ``device``.
    """
    global _qwen_runtime, _qwen_error

    if not QWEN_ENABLE:
        return None
    # Either already loaded (return it) or already failed (runtime is still None).
    if _qwen_runtime is not None or _qwen_error is not None:
        return _qwen_runtime

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Inference is pinned to CPU with full float32 precision.
        target_device = "cpu"
        precision = torch.float32

        qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL)
        qwen_model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL,
            dtype=precision,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        qwen_model.eval()

        _qwen_runtime = {
            "model": qwen_model,
            "tokenizer": qwen_tokenizer,
            "torch": torch,
            "device": target_device,
        }
        logger.info("qwen runtime initialized model=%s device=%s", QWEN_MODEL, target_device)
        return _qwen_runtime
    except Exception as exc:
        _qwen_error = str(exc)
        logger.warning("qwen runtime init failed model=%s err=%s", QWEN_MODEL, _qwen_error)
        return None
def _classify_with_qwen_decision(file_path: str, siglip_breakdown: str, game_score: float) -> dict[str, Any]:
    """Ask Qwen for a final game / not-game decision from the SigLIP score breakdown.

    Args:
        file_path: Path of the screenshot. Unused here; kept for the existing
            call signature.
        siglip_breakdown: Text summary of the zero-shot label scores.
        game_score: Highest game-label score; rendered as a percentage in the
            prompt.

    Returns:
        ``{"verdict": ..., "confidence": ...}`` as produced by
        ``_parse_qwen_decision_response``, or an ``uncertain`` verdict with
        confidence 0.0 when the runtime is unavailable or generation fails.
    """
    runtime = _load_qwen_runtime()
    if runtime is None:
        return {"verdict": "uncertain", "confidence": 0.0}  # fallback to manual review
    model = runtime["model"]
    tokenizer = runtime["tokenizer"]
    torch = runtime["torch"]
    try:
        # Build prompt with SigLIP visual analysis
        prompt = f"""You are a strict game detection AI. Analyze the following image classification scores and determine if this is a video game screen.
SigLIP Classification Results:
{siglip_breakdown}
Max Game Score: {game_score:.2%}
Rules:
- Return GAME only if the screen clearly shows interactive gameplay, a game lobby, loading screen, score HUD, health bars, inventory, map, or other unmistakable game UI.
- If the screen is a chat app, browser page, document, code editor, image gallery, design board, or a page showing images/renders/artwork without gameplay controls or HUD, return NOT_GAME.
- Search engine homepages, search results pages, and query suggestion pages are not games.
- 3D images, 3D renders, and artwork alone are not a game.
- Certificates, diplomas, awards, completion pages, and certificate covers are not games.
- If unsure, return NOT_GAME.
Return JSON only in this format:
{{"verdict":"GAME|NOT_GAME","confidence":0-100}}
Answer:"""
        # Tokenize and generate
        inputs = tokenizer(prompt, return_tensors="pt").to(runtime["device"])
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=40,
                do_sample=False,
            )
        # A causal LM returns prompt + completion. Decode only the completion:
        # the prompt itself contains "NOT_GAME" and a non-JSON format template,
        # which would otherwise dominate the parsed verdict.
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return _parse_qwen_decision_response(completion)
    except Exception as exc:
        logger.warning("qwen decision failed err=%s", repr(exc))
        return {"verdict": "uncertain", "confidence": 0.0}
def _classify_game_with_zero_shot(file_path: str, suspected_game: bool) -> dict[str, Any] | None:
    """Classify a screenshot with the zero-shot model, consulting Qwen when ambiguous.

    All GAME_ZERO_SHOT_LABELS are scored against the image and soft-maxed so
    the labels compete. Decisions are then tried in this order:

    1. A browser-game label is on top but the best non-game label is
       gallery-like or text-heavy and scores high enough -> ``not_game``.
    2. Best game score >= GAME_ZERO_SHOT_GAME_THRESHOLD -> ``game``.
    3. A contrast label is on top with score >= GAME_ZERO_SHOT_NOT_GAME_THRESHOLD
       -> ``not_game``.
    4. Ambiguous scores -> ask Qwen; a Qwen ``uncertain`` answer falls through.
    5. ``uncertain`` when ``suspected_game`` and the game score is >= 0.25,
       otherwise ``not_game``.

    Args:
        file_path: Path of the image to classify.
        suspected_game: Upstream hint that the screenshot might be a game.

    Returns:
        A dict with ``verdict``, ``confidence``, ``reason``, ``source`` and
        ``visual_summary`` (plus ``game_breakdown`` for game verdicts), or
        ``None`` when the runtime is unavailable or inference fails.
    """
    runtime = _load_game_zero_shot_runtime()
    if runtime is None:
        return None
    model = runtime["model"]
    processor = runtime["processor"]
    torch = runtime["torch"]
    Image = runtime["Image"]
    device = runtime["device"]
    try:
        # Load and prepare image
        with Image.open(file_path) as img:
            image = img.convert("RGB")
        # The detailed visual descriptions are used directly as labels (no template).
        inputs = processor(text=GAME_ZERO_SHOT_LABELS, images=image, padding="max_length", return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits_per_image = model(**inputs).logits_per_image
            # Softmax normalises the scores so they sum to 1.0 and labels compete.
            probs = torch.nn.functional.softmax(logits_per_image[0], dim=0).cpu().tolist()
    except Exception as exc:
        logger.warning("game zero-shot inference failed err=%s", repr(exc))
        return None
    if not probs:
        return None

    normalized: list[dict[str, Any]] = [
        {
            "index": i,
            "label": _zero_shot_debug_label(i, full_desc),
            "group": _zero_shot_label_group(i),
            "description": full_desc,
            "score": float(score),
        }
        for i, (full_desc, score) in enumerate(zip(GAME_ZERO_SHOT_LABELS, probs))
    ]
    normalized.sort(key=lambda row: float(row["score"]), reverse=True)
    # Filtering the sorted list keeps each group in descending score order.
    game_rows = [row for row in normalized if row["group"] == "game"]
    not_game_rows = [row for row in normalized if row["group"] == "not_game"]
    top = normalized[0]
    top_score = float(top["score"])

    game_score = float(game_rows[0]["score"]) if game_rows else 0.0
    second_game_score = float(game_rows[1]["score"]) if len(game_rows) > 1 else 0.0
    score_gap = game_score - second_game_score if game_score > 0 else 1.0
    not_game_score = float(not_game_rows[0]["score"]) if not_game_rows else 0.0

    visual_summary = "; ".join(f"{row['label']}={row['score']:.3f}" for row in normalized)
    game_summary = " + ".join(f"{row['label']}={row['score']:.3f}" for row in game_rows[:3])
    zero_shot_source = f"zero-shot:{GAME_ZERO_SHOT_MODEL}"
    qwen_source = f"qwen:{QWEN_MODEL}+{GAME_ZERO_SHOT_MODEL}"

    def _result(verdict: str, confidence: float, reason: str, source: str, with_breakdown: bool = False) -> dict[str, Any]:
        # Single place that fixes the result schema shared by every return below.
        payload: dict[str, Any] = {
            "verdict": verdict,
            "confidence": confidence,
            "reason": reason,
            "source": source,
            "visual_summary": visual_summary,
        }
        if with_breakdown:
            payload["game_breakdown"] = game_summary
        return payload

    # Conditions that hand the decision to Qwen:
    # 1. Normal case: game score below threshold and the top two game scores are close.
    # 2. Game score is moderate and no label at all scores strongly.
    # 3. Game score is high but a non-game label also stands out (strong competition).
    # NOTE(review): clause 3 can never take effect, because the "clear game"
    # return below fires first whenever game_score >= the game threshold.
    # Confirm whether the clear-game check should yield to it.
    is_ambiguous = (
        (game_score < GAME_ZERO_SHOT_GAME_THRESHOLD and
         score_gap < GAME_ZERO_SHOT_AMBIGUOUS_THRESHOLD) or
        (game_score >= 0.15 and game_score < GAME_ZERO_SHOT_GAME_THRESHOLD and
         top_score < GAME_ZERO_SHOT_NOT_GAME_THRESHOLD) or
        (game_score >= GAME_ZERO_SHOT_GAME_THRESHOLD and not_game_score > 0.25)
    )

    top_is_browser_game = top["index"] in GAME_ZERO_SHOT_BROWSER_GAME_LABEL_INDICES
    best_not_game_desc = str(not_game_rows[0]["description"]) if not_game_rows else ""

    # Browser-game labels are broad; if the strongest non-game signal is an image/gallery/render
    # page, prefer not_game unless the game score is overwhelming.
    if (
        top_is_browser_game
        and game_score < 0.85
        and not_game_rows
        and not_game_score >= 0.10
        and _is_gallery_like_not_game(best_not_game_desc)
    ):
        return _result("not_game", not_game_score, "gallery_or_render_conflict", zero_shot_source)

    # Browser-game labels are also confused by text-heavy educational/article websites.
    # If a strong text-page label competes with broad browser-game cues, prefer not_game.
    if (
        top_is_browser_game
        and game_score < 0.90
        and not_game_rows
        and not_game_score >= 0.12
        and _is_text_heavy_not_game(best_not_game_desc)
    ):
        return _result("not_game", not_game_score, "text_article_conflict", zero_shot_source)

    # Clear game: high confidence
    if game_score >= GAME_ZERO_SHOT_GAME_THRESHOLD:
        return _result("game", game_score, "game_ui_detected", zero_shot_source, with_breakdown=True)

    # Clear non-game: strong contrast signal
    if (
        top["index"] in GAME_ZERO_SHOT_CLEAR_NOT_GAME_LABEL_INDICES
        and top_score >= GAME_ZERO_SHOT_NOT_GAME_THRESHOLD
    ):
        return _result("not_game", top_score, "no_game_evidence", zero_shot_source)

    # Ambiguous case: close game scores or low overall scores. Delegate to Qwen.
    if is_ambiguous:
        logger.info(
            "game-detect ambiguous case game_score=%.2f score_gap=%.3f top_label=%s top_score=%.2f, delegating to qwen",
            game_score,
            score_gap,
            top["label"],
            top_score,
        )
        qwen_result = _classify_with_qwen_decision(file_path, visual_summary, game_score)
        qwen_verdict = qwen_result["verdict"]
        qwen_confidence = float(qwen_result["confidence"])
        if _should_accept_qwen_game_decision(qwen_verdict, qwen_confidence):
            # Qwen may answer on a 0-100 scale; bring it to 0-1 before comparing.
            scaled = qwen_confidence / 100.0 if qwen_confidence > 1.0 else qwen_confidence
            return _result("game", max(game_score, scaled), "qwen_game_decision", qwen_source, with_breakdown=True)
        if qwen_verdict == "not_game":
            return _result("not_game", max(top_score, 0.3), "qwen_not_game_decision", qwen_source)
        if qwen_verdict == "game":
            return _result("not_game", max(top_score, 0.3), "qwen_game_below_threshold", qwen_source)
        # Qwen gave no usable verdict (disabled, failed to load, or unparsable):
        # fall through to the zero-shot fallback rather than reporting a Qwen decision.
        logger.info("game-detect qwen verdict=%s, using zero-shot fallback", qwen_verdict)

    # Final fallback if not ambiguous or Qwen uncertain
    if suspected_game and game_score >= 0.25:
        return _result("uncertain", max(game_score, 0.55), "conflicting_signals", zero_shot_source)
    return _result("not_game", max(top_score, 0.3), "no_game_evidence", zero_shot_source)
def classify_sensitive_content(file_path: str, filename: str) -> tuple[bool, float, str]:
    """Decide whether an image is sensitive, returning ``(flagged, score, reason)``.

    The image model's result is used whenever it produces one. Otherwise the
    lower-cased base name of ``file_path`` plus ``filename`` is searched for any
    SENSITIVE_KEYWORDS substring.
    """
    verdict = _classify_sensitive_with_model(file_path)
    if verdict is not None:
        return verdict

    haystack = f"{Path(file_path).name} {filename}".lower()
    for keyword in SENSITIVE_KEYWORDS:
        if keyword in haystack:
            return True, 0.65, "sensitive-keyword-fallback"
    return False, 0.05, "no-sensitive-signal"
def _load_nsfw_runtime() -> dict[str, Any] | None:
    """Return the cached NSFW classifier runtime, loading it from local files on first use.

    The model architecture, class count, label names and preprocessing config
    come from NSFW_CONFIG_PATH; weights come from NSFW_WEIGHTS_PATH. A failed
    load is stored in ``_nsfw_error``, logged once, and not retried.

    Returns:
        A dict with ``torch``, ``Image``, ``model``, ``transforms`` and
        ``label_names`` (lower-cased), or ``None`` when unavailable.
    """
    global _nsfw_runtime
    global _nsfw_error
    if _nsfw_runtime is not None:
        return _nsfw_runtime
    if _nsfw_error is not None:
        return None
    try:
        import timm
        import torch
        from PIL import Image
        from safetensors.torch import load_file

        if not NSFW_CONFIG_PATH.exists() or not NSFW_WEIGHTS_PATH.exists():
            _nsfw_error = "missing-local-model-files"
            logger.warning("nsfw runtime init failed err=%s", _nsfw_error)
            return None

        config_data = json.loads(NSFW_CONFIG_PATH.read_text(encoding="utf-8"))
        architecture = str(config_data.get("architecture", "vit_tiny_patch16_384"))
        num_classes = int(config_data.get("num_classes", 2))
        label_names = [str(x).lower() for x in config_data.get("label_names", ["nsfw", "sfw"])]
        pretrained_cfg = config_data.get("pretrained_cfg", {})

        model = timm.create_model(architecture, pretrained=False, num_classes=num_classes).eval()
        state_dict = load_file(str(NSFW_WEIGHTS_PATH), device="cpu")
        # strict=False tolerates key mismatches; surface them so a partially
        # initialised classifier does not go unnoticed.
        load_report = model.load_state_dict(state_dict, strict=False)
        missing_keys = list(getattr(load_report, "missing_keys", ()))
        unexpected_keys = list(getattr(load_report, "unexpected_keys", ()))
        if missing_keys or unexpected_keys:
            logger.warning(
                "nsfw weights loaded with key mismatch missing=%d unexpected=%d",
                len(missing_keys),
                len(unexpected_keys),
            )

        # Use local config for preprocessing so inference does not depend on remote metadata.
        model.pretrained_cfg = {**getattr(model, "pretrained_cfg", {}), **pretrained_cfg, "label_names": label_names}
        data_config = timm.data.resolve_model_data_config(model)
        transforms = timm.data.create_transform(**data_config, is_training=False)

        _nsfw_runtime = {
            "torch": torch,
            "Image": Image,
            "model": model,
            "transforms": transforms,
            "label_names": list(label_names),
        }
        logger.info("nsfw runtime initialized architecture=%s", architecture)
        return _nsfw_runtime
    except Exception as exc:
        _nsfw_error = str(exc)
        logger.warning("nsfw runtime init failed err=%s", _nsfw_error)
        return None
def _classify_sensitive_with_model(file_path: str) -> tuple[bool, float, str] | None:
    """Score an image with the NSFW model.

    Returns ``None`` when the runtime is unavailable. Otherwise returns
    ``(flagged, nsfw_score, "timm-marqo-nsfw")``, where ``flagged`` means the
    score reached NSFW_THRESHOLD. Errors while opening or scoring the image are
    not caught here.
    """
    runtime = _load_nsfw_runtime()
    if runtime is None:
        return None

    torch = runtime["torch"]
    Image = runtime["Image"]
    model = runtime["model"]
    transforms = runtime["transforms"]
    label_names = runtime["label_names"]

    with Image.open(file_path) as handle:
        rgb = handle.convert("RGB")
        with torch.no_grad():
            batch = transforms(rgb).unsqueeze(0)
            probabilities = model(batch).softmax(dim=-1).cpu()[0]

    scores = [float(value) for value in probabilities.tolist()]
    nsfw_score = _extract_nsfw_score(scores, label_names)
    flagged = nsfw_score >= NSFW_THRESHOLD
    return (flagged, nsfw_score, "timm-marqo-nsfw")
def _extract_nsfw_score(scores: list[float], labels: list[str]) -> float:
for idx, label in enumerate(labels):
if "nsfw" in label:
return scores[idx]
if len(scores) >= 2:
return scores[1]
return scores[0] if scores else 0.0
def classify_screen(image, threshold: float = GAME_THRESHOLD) -> dict[str, Any]:
    """
    Simplified game classification using a pairwise softmax.

    The highest game-label logit is compared with the highest non-game-label
    logit; a softmax over just those two values gives the game probability.

    Args:
        image: PIL Image object (converted to RGB when it has ``convert``).
        threshold: Game probability threshold. Defaults to GAME_THRESHOLD
            (0.65 unless overridden through the environment).

    Returns:
        Dictionary with:
        - is_game: bool, True when game_prob is strictly above ``threshold``
        - game_prob: float probability of being a game (0-1), rounded to 4 places
        - matched_game_label: str best matching game label
        - matched_not_game_label: str best matching non-game label

    Raises:
        RuntimeError: when the zero-shot runtime is unavailable.
        Exception: any inference error is logged and re-raised unchanged.
    """
    # torch is taken from the runtime rather than imported here, so a missing
    # torch installation surfaces as the RuntimeError below, not an ImportError.
    runtime = _load_game_zero_shot_runtime()
    if runtime is None:
        raise RuntimeError("Game classification runtime not available")
    model = runtime["model"]
    processor = runtime["processor"]
    torch_lib = runtime["torch"]
    device = runtime["device"]
    try:
        # Prepare combined labels: game labels + non-game labels
        all_labels = GAME_LABELS + NOT_GAME_LABELS
        n_game = len(GAME_LABELS)
        # Ensure image is in RGB format
        if hasattr(image, "convert"):
            image = image.convert("RGB")
        # Process inputs
        inputs = processor(
            text=all_labels,
            images=image,
            return_tensors="pt",
            padding="max_length"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Run inference
        with torch_lib.no_grad():
            logits = model(**inputs).logits_per_image[0]
        # Extract max scores from game and non-game groups
        game_logits = logits[:n_game]
        not_game_logits = logits[n_game:]
        game_score = game_logits.max().item()
        not_game_score = not_game_logits.max().item()
        # Get best matching labels
        best_game_label = GAME_LABELS[game_logits.argmax().item()]
        best_not_game_label = NOT_GAME_LABELS[not_game_logits.argmax().item()]
        # Calculate game probability using softmax over the two best logits
        pair = torch_lib.tensor([game_score, not_game_score])
        prob_game = torch_lib.nn.functional.softmax(pair, dim=0)[0].item()
        return {
            "is_game": prob_game > threshold,
            "game_prob": round(prob_game, 4),
            "matched_game_label": best_game_label,
            "matched_not_game_label": best_not_game_label,
        }
    except Exception as exc:
        logger.warning("classify_screen failed err=%s", repr(exc))
        raise