Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import logging | |
| import warnings | |
| import threading | |
| # ββ Silence noisy third-party loggers before any imports ββββββββββββββββββββββ | |
| os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" | |
| os.environ["TRANSFORMERS_VERBOSITY"] = "error" | |
| os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| warnings.filterwarnings("ignore") | |
| for _lib in ("transformers", "huggingface_hub", "huggingface_hub.utils._http", | |
| "urllib3", "httpx", "httpcore", "filelock", "fsspec"): | |
| logging.getLogger(_lib).setLevel(logging.CRITICAL) | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| # ββ Logging βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logging.basicConfig(level=logging.INFO, | |
| format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") | |
| logger = logging.getLogger("inference") | |
| # ββ Paths βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| MODEL_PATH = os.path.join(BASE_DIR, "models", "custom_cnn_standalone.pt") | |
| # ββ Devices βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| hf_device = 0 if torch.cuda.is_available() else -1 # HF pipeline device id | |
| logger.info(f"Using device: {device}") | |
| # ββ Custom CNN hyper-parameters βββββββββββββββββββββββββββββββββββββββββββββββ | |
| IMG_SIZE = 224 | |
| MEAN = [0.485, 0.456, 0.406] | |
| STD = [0.229, 0.224, 0.225] | |
| # THRESHOLD TUNING GUIDE (change this one value to fix wrong predictions): | |
| # 0.50 β balanced (default) | |
| # 0.60 β stricter: catches MORE fakes, may flag some real as fake | |
| # 0.65 β very strict: catches MOST fakes (use if fakes are slipping through) | |
| # 0.40 β lenient: fewer false alarms on real images | |
| OPTIMAL_THRESHOLD = 0.60 # raised from 0.5 β catches more AI fakes | |
| # WEIGHTED VOTING β CNN gets 2 votes, each HF model gets 1 vote | |
| # CNN is primary authority; HF models can only override when BOTH agree | |
| CNN_VOTES = 2 | |
| HF_VOTES = 1 | |
| # ββ Image pre-processing ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| transform = transforms.Compose([ | |
| transforms.Resize((IMG_SIZE, IMG_SIZE)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=MEAN, std=STD), | |
| ]) | |
| # ββ Load Custom CNN (TorchScript) βββββββββββββββββββββββββββββββββββββββββββββ | |
| if not os.path.exists(MODEL_PATH): | |
| _fallback = os.path.join(BASE_DIR, "custom_cnn_standalone.pt") | |
| if os.path.exists(_fallback): | |
| MODEL_PATH = _fallback | |
| else: | |
| raise FileNotFoundError( | |
| f"Model not found at: {MODEL_PATH} or {_fallback}") | |
| try: | |
| logger.info(f"Loading TorchScript model from {MODEL_PATH} ...") | |
| model = torch.jit.load(MODEL_PATH, map_location=device) | |
| model.eval() | |
| logger.info("Custom CNN loaded and set to eval mode.") | |
| except Exception as exc: | |
| logger.error(f"Failed to load model: {exc}") | |
| raise | |
| # ββ Image loader (with validation) βββββββββββββββββββββββββββββββββββββββββββ | |
| def _load_image(image_source) -> Image.Image: | |
| """Accept a file path, raw bytes, or a PIL Image. Raise ValueError on failure.""" | |
| try: | |
| if isinstance(image_source, Image.Image): | |
| return image_source | |
| if isinstance(image_source, bytes): | |
| return Image.open(io.BytesIO(image_source)) | |
| if not os.path.exists(image_source): | |
| raise FileNotFoundError(f"Image not found: {image_source}") | |
| return Image.open(image_source) | |
| except Exception as exc: | |
| raise ValueError(f"Could not load image: {exc}") from exc | |
| # ββ Custom CNN inference ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _predict_probability(image: Image.Image) -> float: | |
| """Return sigmoid probability from the Custom CNN (higher = more 'real').""" | |
| tensor = transform(image.convert("RGB")).unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
| return torch.sigmoid(model(tensor)).item() | |
| # ββ HF model lazy-load (thread-safe sentinel pattern) ββββββββββββββββββββββββ | |
| # None = not yet attempted | |
| # False = permanently failed (never retry) | |
| # obj = loaded pipeline | |
| _model_hf1 = None # prithivMLmods/Deepfake-Detection-Exp-02-21 | |
| _model_hf2 = None # prithivMLmods/Deep-Fake-Detector-v2-Model | |
| _hf_lock = threading.Lock() | |
| def _load_hf_models() -> None: | |
| """Load both HF models exactly once under a mutex.""" | |
| global _model_hf1, _model_hf2 | |
| with _hf_lock: | |
| from transformers import pipeline as hf_pipeline | |
| if _model_hf1 is None: | |
| try: | |
| # HF model 1 loading log suppressed | |
| _model_hf1 = hf_pipeline( | |
| "image-classification", | |
| model="prithivMLmods/Deepfake-Detection-Exp-02-21", | |
| device=hf_device, | |
| ) | |
| # HF model 1 loaded log suppressed | |
| except Exception as exc: | |
| logger.warning(f"HF model 1 failed to load: {exc}") | |
| _model_hf1 = False | |
| if _model_hf2 is None: | |
| try: | |
| # HF model 2 loading log suppressed | |
| _model_hf2 = hf_pipeline( | |
| "image-classification", | |
| model="prithivMLmods/Deep-Fake-Detector-v2-Model", | |
| device=hf_device, | |
| ) | |
| # HF model 2 loaded log suppressed | |
| except Exception as exc: | |
| logger.warning(f"HF model 2 failed to load: {exc}") | |
| _model_hf2 = False | |
| def _predict_hf1(pil_image: Image.Image) -> str: | |
| """Deepfake-Detection-Exp-02-21 β labels: 'Deepfake' / 'Real' (correct).""" | |
| if _model_hf1 is None: | |
| _load_hf_models() | |
| if not _model_hf1: | |
| return "unknown" | |
| try: | |
| best = max(_model_hf1(pil_image), key=lambda x: x["score"]) | |
| return "ai" if best["label"].lower() == "deepfake" else "real" | |
| except Exception: | |
| return "unknown" | |
| def _predict_hf2(pil_image: Image.Image) -> str: | |
| """Deep-Fake-Detector-v2-Model β labels are INVERTED in HF config: | |
| 'Realism' actually means Deepfake, 'Deepfake' actually means Real.""" | |
| if _model_hf2 is None: | |
| _load_hf_models() | |
| if not _model_hf2: | |
| return "unknown" | |
| try: | |
| best = max(_model_hf2(pil_image), key=lambda x: x["score"]) | |
| # Inverted: 'Realism' label β AI fake | |
| return "ai" if best["label"].lower() == "realism" else "real" | |
| except Exception: | |
| return "unknown" | |
| # ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def predict_label(image_source) -> str: | |
| """Return 'real' or 'ai'.""" | |
| label, _ = predict_with_confidence(image_source) | |
| return label | |
| def predict_with_confidence(image_source) -> tuple: | |
| """ | |
| Return (label, confidence_percent) using weighted majority voting. | |
| Vote weights: | |
| Custom CNN β 2 votes (primary authority) | |
| HF model 1 β 1 vote | |
| HF model 2 β 1 vote | |
| This means HF models can ONLY override the CNN when: | |
| - BOTH HF models agree with each other, AND | |
| - Their combined 2 votes tie or beat CNN's 2 votes | |
| β In case of a tie, CNN label wins (it is the tiebreaker). | |
| Confidence = (winning_weighted_votes / total_weighted_votes) * 100 | |
| """ | |
| from concurrent.futures import ThreadPoolExecutor | |
| image = _load_image(image_source) | |
| # Run all 3 models concurrently for minimum latency | |
| with ThreadPoolExecutor(max_workers=3) as ex: | |
| fut_cnn = ex.submit(_predict_probability, image) | |
| fut_hf1 = ex.submit(_predict_hf1, image) | |
| fut_hf2 = ex.submit(_predict_hf2, image) | |
| probability = fut_cnn.result() | |
| hf1_label = fut_hf1.result() | |
| hf2_label = fut_hf2.result() | |
| # ββ CNN decision ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| cnn_label = "real" if probability > OPTIMAL_THRESHOLD else "ai" | |
| cnn_conf = probability if cnn_label == "real" else (1.0 - probability) | |
| # ββ Weighted vote tally βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ai_votes = CNN_VOTES if cnn_label == "ai" else 0 | |
| real_votes = CNN_VOTES if cnn_label == "real" else 0 | |
| total_weight = CNN_VOTES | |
| if hf1_label != "unknown": | |
| if hf1_label == "ai": | |
| ai_votes += HF_VOTES | |
| else: | |
| real_votes += HF_VOTES | |
| total_weight += HF_VOTES | |
| if hf2_label != "unknown": | |
| if hf2_label == "ai": | |
| ai_votes += HF_VOTES | |
| else: | |
| real_votes += HF_VOTES | |
| total_weight += HF_VOTES | |
| # ββ Final decision (CNN wins ties) ββββββββββββββββββββββββββββββββββββββββ | |
| if ai_votes > real_votes: | |
| final_label = "ai" | |
| elif real_votes > ai_votes: | |
| final_label = "real" | |
| else: | |
| final_label = cnn_label # tie β CNN is tiebreaker | |
| # ββ Confidence ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| winning_votes = max(ai_votes, real_votes) | |
| vote_confidence = (winning_votes / total_weight) * 100 | |
| # Blend vote confidence with CNN raw probability for a smoother score | |
| cnn_confidence_pct = cnn_conf * 100 | |
| final_confidence = (vote_confidence * 0.6) + (cnn_confidence_pct * 0.4) | |
| logger.debug( | |
| f"CNN={cnn_label}({cnn_conf:.2f}) " | |
| f"Votes ai={ai_votes} real={real_votes}/{total_weight} " | |
| f"β {final_label} ({final_confidence:.1f}%)" | |
| ) | |
| return final_label, round(final_confidence, 2) | |