Spaces:
Runtime error
Runtime error
| import io | |
| import os | |
| from functools import lru_cache | |
| from typing import Any, Dict, List, Tuple | |
| import numpy as np | |
| import requests | |
| from PIL import Image | |
| FAKE_LABEL = "FAKE" | |
| REAL_LABEL = "REAL" | |
| FAKE_HINTS = ("fake", "deepfake", "ai", "generated", "synthetic", "manipulated") | |
| REAL_HINTS = ("real", "human", "authentic", "genuine", "original") | |
| def _resolve_default_path(env_var: str, primary: str, fallback: str) -> str: | |
| configured = os.getenv(env_var, "").strip() | |
| if configured: | |
| return configured | |
| if os.path.isfile(primary): | |
| return primary | |
| return fallback | |
def normalize_scores(fake_score: float, real_score: float) -> Tuple[float, float]:
    """Clamp both scores at zero and rescale them to sum to 1.0.

    Returns (0.5, 0.5) when both clamped scores are zero so callers always
    receive a valid probability pair.
    """
    clamped_fake = float(max(fake_score, 0.0))
    clamped_real = float(max(real_score, 0.0))
    denom = clamped_fake + clamped_real
    if denom <= 0:
        return 0.5, 0.5
    return clamped_fake / denom, clamped_real / denom
| def _ensure_rgb_uint8(image: np.ndarray) -> np.ndarray: | |
| if image is None: | |
| raise ValueError("Image is required") | |
| arr = np.asarray(image) | |
| if arr.ndim == 2: | |
| arr = np.stack([arr] * 3, axis=-1) | |
| if arr.ndim != 3: | |
| raise ValueError(f"Expected image with 2 or 3 dimensions, got shape {arr.shape}") | |
| if arr.shape[-1] == 4: | |
| arr = arr[..., :3] | |
| if arr.shape[-1] != 3: | |
| raise ValueError(f"Expected 3 channels, got shape {arr.shape}") | |
| if arr.dtype != np.uint8: | |
| arr = np.clip(arr, 0, 255).astype(np.uint8) | |
| return arr | |
def _to_png_bytes(image: np.ndarray) -> bytes:
    """Encode *image* as PNG bytes after coercing it to uint8 RGB."""
    pil_img = Image.fromarray(_ensure_rgb_uint8(image)).convert("RGB")
    with io.BytesIO() as buffer:
        pil_img.save(buffer, format="PNG")
        return buffer.getvalue()
| def _extract_state_dict(checkpoint: Any) -> Any: | |
| if not isinstance(checkpoint, dict): | |
| return checkpoint | |
| for key in ("state_dict", "model_state_dict", "model"): | |
| maybe = checkpoint.get(key) | |
| if isinstance(maybe, dict): | |
| return maybe | |
| return checkpoint | |
def _predict_payload(fake_score: float, real_score: float) -> Dict[str, Any]:
    """Build the standard prediction dict from raw fake/real scores.

    Scores are normalized into probabilities; the label is whichever class
    wins (ties go to FAKE) and confidence is that class's probability.
    """
    fake_prob, real_prob = normalize_scores(fake_score, real_score)
    is_fake = fake_prob >= real_prob
    return {
        "label": FAKE_LABEL if is_fake else REAL_LABEL,
        "confidence": fake_prob if is_fake else real_prob,
        "scores": {
            FAKE_LABEL: round(fake_prob, 4),
            REAL_LABEL: round(real_prob, 4),
        },
    }
def _load_torch_model(model_path: str, arch: str):
    """Load a timm classifier checkpoint and build its eval-time transform.

    Returns a (model, transform, device) triple. The checkpoint may be a
    bare state dict or a dict wrapping one (see _extract_state_dict); when
    it is a dict it may also carry "arch" and "image_size" metadata.
    """
    import torch
    from timm import create_model
    from torchvision import transforms
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects -- only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    # Checkpoint-embedded arch is used only when the caller passed an empty one.
    if isinstance(checkpoint, dict) and "arch" in checkpoint and not arch:
        arch = checkpoint["arch"]
    arch = arch or "efficientnet_b4"
    # Two output classes: fake vs real.
    model = create_model(arch, pretrained=False, num_classes=2)
    state_dict = _extract_state_dict(checkpoint)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Common fix for DataParallel checkpoints with "module." prefixes.
        remapped = {}
        for key, value in state_dict.items():
            new_key = key[7:] if key.startswith("module.") else key
            remapped[new_key] = value
        model.load_state_dict(remapped)
    model.to(device).eval()
    # Fall back to 380 when the checkpoint carries no image_size metadata
    # (presumably matching the efficientnet_b4 default arch -- TODO confirm).
    image_size = int(checkpoint.get("image_size", 380)) if isinstance(checkpoint, dict) else 380
    transform = transforms.Compose(
        [
            # Resize slightly larger, then center-crop to the target size.
            transforms.Resize(image_size + 20),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # Standard ImageNet normalization constants.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return model, transform, device
def predict_local_torch(image: np.ndarray) -> Dict[str, Any]:
    """Classify *image* with the local PyTorch checkpoint.

    The checkpoint path is resolved via LOCAL_MODEL_PATH (falling back to
    the default model locations) and the architecture via MODEL_ARCH.

    Raises:
        FileNotFoundError: when no checkpoint exists at the resolved path.
    """
    import torch

    arch = os.getenv("MODEL_ARCH", "efficientnet_b4")
    model_path = _resolve_default_path(
        "LOCAL_MODEL_PATH",
        "models/pytorch_model.pth",
        "models/reference/pytorch_model.pth",
    )
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"Missing local model: {model_path}. "
            "Train one with train.py or set LOCAL_MODEL_PATH to an existing checkpoint."
        )
    model, transform, device = _load_torch_model(model_path, arch)
    pil_img = Image.fromarray(_ensure_rgb_uint8(image)).convert("RGB")
    batch = transform(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1).squeeze(0).cpu().numpy().tolist()
    # Index 0 feeds the fake score, index 1 the real score.
    return _predict_payload(float(probabilities[0]), float(probabilities[1]))
def _load_onnx_session(onnx_path: str):
    """Create a CPU-only onnxruntime inference session for *onnx_path*."""
    import onnxruntime as ort

    return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
def predict_local_onnx(image: np.ndarray) -> Dict[str, Any]:
    """Classify *image* with the local ONNX model.

    The model path is resolved via LOCAL_ONNX_PATH (falling back to the
    default locations) and the fallback input size via IMAGE_SIZE. Both
    NCHW and NHWC graphs are supported; static graph dimensions override
    the fallback size.

    Raises:
        FileNotFoundError: when no ONNX file exists at the resolved path.
        RuntimeError: when the graph input is not 4-D with 3 channels.
    """
    onnx_path = _resolve_default_path(
        "LOCAL_ONNX_PATH",
        "models/model.onnx",
        "models/reference/efficientnet.onnx",
    )
    default_image_size = int(os.getenv("IMAGE_SIZE", "380"))
    if not os.path.isfile(onnx_path):
        raise FileNotFoundError(
            f"Missing ONNX model: {onnx_path}. Export one with export_onnx.py or set LOCAL_ONNX_PATH."
        )
    from PIL import ImageOps
    session = _load_onnx_session(onnx_path)
    input_meta = session.get_inputs()[0]
    input_name = input_meta.name
    input_shape = list(input_meta.shape)
    # Detect channel layout from the static graph shape. Dynamic dims are
    # non-int (symbolic), so the isinstance checks below fall back to
    # default_image_size for any symbolic height/width.
    is_nchw = len(input_shape) == 4 and input_shape[1] == 3
    is_nhwc = len(input_shape) == 4 and input_shape[-1] == 3
    if is_nchw:
        height = input_shape[2] if isinstance(input_shape[2], int) else default_image_size
        width = input_shape[3] if isinstance(input_shape[3], int) else default_image_size
    elif is_nhwc:
        height = input_shape[1] if isinstance(input_shape[1], int) else default_image_size
        width = input_shape[2] if isinstance(input_shape[2], int) else default_image_size
    else:
        raise RuntimeError(f"Unsupported ONNX input shape: {input_shape}")
    rgb = _ensure_rgb_uint8(image)
    pil_img = Image.fromarray(rgb).convert("RGB")
    # Aspect-preserving crop-and-resize to the model's input resolution.
    resized = ImageOps.fit(pil_img, (width, height), method=Image.Resampling.BILINEAR)
    arr = np.asarray(resized).astype(np.float32) / 255.0
    # Standard ImageNet normalization constants (matches the torch path).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    arr = (arr - mean) / std
    if is_nchw:
        # HWC -> CHW, then add the batch dimension.
        arr = np.transpose(arr, (2, 0, 1))[None, ...]
    else:
        arr = arr[None, ...]
    outputs = session.run(None, {input_name: arr})
    logits = np.asarray(outputs[0])
    logits = logits[0] if logits.ndim > 1 else logits
    # Heuristic: if the output already looks like a probability vector
    # (all in [0, 1] and summing to ~1), use it directly; otherwise apply
    # a numerically-stable softmax.
    if np.all(logits >= 0.0) and np.all(logits <= 1.0) and np.isclose(np.sum(logits), 1.0, atol=1e-3):
        probs = logits
    else:
        logits = logits - np.max(logits)
        exps = np.exp(logits)
        probs = exps / np.sum(exps)
    # Index 0 feeds the fake score, index 1 the real score.
    fake_score = float(probs[0])
    real_score = float(probs[1])
    return _predict_payload(fake_score, real_score)
| def _flatten_hf_output(payload: Any) -> List[Dict[str, Any]]: | |
| if isinstance(payload, list): | |
| if payload and isinstance(payload[0], list): | |
| payload = payload[0] | |
| return [x for x in payload if isinstance(x, dict)] | |
| if isinstance(payload, dict): | |
| if "label" in payload and "score" in payload: | |
| return [payload] | |
| if "error" in payload: | |
| raise RuntimeError(str(payload["error"])) | |
| return [] | |
def _scores_from_hf(entries: List[Dict[str, Any]]) -> Tuple[float, float]:
    """Map HF label/score entries onto normalized (fake, real) probabilities.

    Labels are matched case-insensitively against FAKE_HINTS/REAL_HINTS,
    keeping the highest score per side. When no label matches at all, the
    first two entries (or the lone entry's label) decide the split.
    """
    fake_best = 0.0
    real_best = 0.0
    for entry in entries:
        label = str(entry.get("label", "")).lower().strip()
        score = float(entry.get("score", 0.0))
        if any(hint in label for hint in FAKE_HINTS):
            fake_best = max(fake_best, score)
        if any(hint in label for hint in REAL_HINTS):
            real_best = max(real_best, score)
    if entries and fake_best == 0.0 and real_best == 0.0:
        # Nothing matched a hint: fall back to positional interpretation.
        if len(entries) >= 2:
            fake_best = float(entries[0].get("score", 0.0))
            real_best = float(entries[1].get("score", 0.0))
        else:
            sole = entries[0]
            sole_score = float(sole.get("score", 0.0))
            if any(hint in str(sole.get("label", "")).lower() for hint in REAL_HINTS):
                real_best = sole_score
            else:
                fake_best = sole_score
    return normalize_scores(fake_best, real_best)
def predict_hf_api(image: np.ndarray) -> Dict[str, Any]:
    """Classify *image* via the HuggingFace inference API.

    Requires HF_TOKEN, plus either HF_INFERENCE_ENDPOINT (a full URL) or
    DEEPFAKE_MODEL_ID (a hub model id). The image is uploaded as PNG bytes.

    Raises:
        RuntimeError: on missing configuration, an HTTP error status, or an
            unparseable response payload.
    """
    token = os.getenv("HF_TOKEN", "").strip()
    endpoint = os.getenv("HF_INFERENCE_ENDPOINT", "").strip()
    model_id = os.getenv("DEEPFAKE_MODEL_ID", "").strip()
    if not token:
        raise RuntimeError("Missing HF_TOKEN secret.")
    if not (endpoint or model_id):
        raise RuntimeError("Set either HF_INFERENCE_ENDPOINT or DEEPFAKE_MODEL_ID.")
    url = endpoint if endpoint else f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(
        url,
        headers={"Authorization": f"Bearer {token}"},
        data=_to_png_bytes(image),
        timeout=180,
    )
    if response.status_code >= 400:
        raise RuntimeError(f"HF API error {response.status_code}: {response.text[:240]}")
    payload = response.json()
    entries = _flatten_hf_output(payload)
    if not entries:
        raise RuntimeError(f"Unexpected response: {payload}")
    fake_score, real_score = _scores_from_hf(entries)
    return _predict_payload(fake_score, real_score)