| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision import transforms, models |
| from PIL import Image |
| import numpy as np |
| import pandas as pd |
| import requests |
| from io import BytesIO |
| import tempfile |
| import random |
| import base64 |
| import shutil |
| from typing import Any, Dict, List, Optional, Tuple |
| from collections import deque |
| from datetime import datetime, timezone |
|
|
| import cv2 |
| import yt_dlp |
|
|
| from fastapi import FastAPI, File, Form, Request, UploadFile |
| from fastapi.responses import HTMLResponse |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.templating import Jinja2Templates |
| import warnings |
# Silence every Python warning process-wide.
# NOTE(review): this also hides any security/deprecation warnings raised by
# libraries; consider narrowing the filter to specific categories/modules.
warnings.filterwarnings("ignore")

# Process-wide state shared by all request handlers.
detector = None             # active AIImageDetector, or None until a model is loaded
current_model_path = None   # resolved checkpoint path behind `detector`
history = deque(maxlen=10)  # bounded log of recent analyses (newest first via appendleft)
|
|
class AIDetectorModel(nn.Module):
    """RegNet-Y-16GF backbone with a replaced MLP head for REAL/FAKE classification.

    Args:
        num_classes: width of the final output layer.
        dropout_prob: dropout probability applied after every hidden head layer.
        load_pretrained: when True, build the backbone with the
            ``IMAGENET1K_SWAG_E2E_V1`` weights (a large download); pass False when
            every weight will be overwritten from a checkpoint anyway.

    All backbone parameters are frozen except ``trunk_output.block4`` (if that
    attribute exists); the new ``fc`` head stays trainable. The backbone's global
    pooling is swapped from average pooling to max pooling.
    """

    def __init__(self, num_classes=2, dropout_prob=0.3, load_pretrained=True):
        super().__init__()

        if load_pretrained:
            print("Loading RegNet with pre-trained weights (large download)...")
            weights = models.RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
        else:
            print("Creating RegNet architecture without pre-trained weights...")
            weights = None
        self.backbone = models.regnet_y_16gf(weights=weights)

        # Freeze the whole backbone, then unfreeze only its last stage (when present).
        self.backbone.requires_grad_(False)
        last_stage = getattr(getattr(self.backbone, 'trunk_output', None), 'block4', None)
        if last_stage is not None:
            last_stage.requires_grad_(True)

        self.backbone.avgpool = nn.AdaptiveMaxPool2d(output_size=(1, 1))

        # Head: in_features -> 2048 -> 1024 -> 512 -> num_classes, each hidden layer
        # followed by SiLU + Dropout. The resulting Sequential indices (0..9) must stay
        # exactly as before so saved state_dict keys keep matching.
        widths = [self.backbone.fc.in_features, 2048, 1024, 512]
        head = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            head.extend([nn.Linear(width_in, width_out), nn.SiLU(), nn.Dropout(dropout_prob)])
        head.append(nn.Linear(widths[-1], num_classes))
        self.backbone.fc = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits from the backbone (which now ends in the MLP head)."""
        return self.backbone(x)
|
|
def analyze_checkpoint(checkpoint_path):
    """Load a checkpoint on CPU and decide whether it already carries backbone weights.

    Accepts either a training checkpoint dict holding ``model_state_dict`` or a bare
    state_dict. Keys are classified by the prefixes AIDetectorModel produces:
    ``backbone.stem.*`` and ``backbone.trunk_output.*`` are feature-extractor
    weights, ``backbone.fc.*`` is the classifier head. (Substring tests on "conv"
    and "fc" do not work here: torchvision's RegNet layer names contain no "conv",
    and "fc" also occurs inside the squeeze-excitation blocks.)

    Returns:
        (has_full_backbone, checkpoint): the bool tells the caller it may skip the
        ImageNet weight download; ``checkpoint`` is the loaded object (handed back so
        it is not read twice), or None when loading/inspection failed, in which case
        has_full_backbone is False.
    """
    try:
        print(f"Analyzing checkpoint: {checkpoint_path}")
        # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            print("Found structured checkpoint with model_state_dict")
        else:
            state_dict = checkpoint
            print("Found direct state_dict checkpoint")

        all_keys = list(state_dict.keys())
        backbone_params = [k for k in all_keys if k.startswith('backbone.')]
        stem_params = [k for k in backbone_params if k.startswith('backbone.stem.')]
        trunk_params = [k for k in backbone_params if k.startswith('backbone.trunk_output.')]
        classifier_params = [k for k in backbone_params if k.startswith('backbone.fc.')]

        print("Parameter analysis:")
        print(f" • Total tensors: {len(all_keys)}")
        print(f" • Backbone tensors: {len(backbone_params)}")
        print(f" • Backbone stem tensors: {len(stem_params)}")
        print(f" • Backbone trunk tensors: {len(trunk_params)}")
        print(f" • Classifier tensors: {len(classifier_params)}")

        # A usable feature extractor needs the stem plus a substantial trunk; the
        # ">100" threshold is kept from the previous heuristic to reject head-only files.
        has_full_backbone = bool(stem_params) and len(trunk_params) > 100

        if has_full_backbone:
            print("Complete model detected - backbone weights included.")
            print("Will skip RegNet download for faster loading.")
        else:
            print("Incomplete model detected - backbone weights missing.")
            print("Will download RegNet pre-trained weights.")

        return has_full_backbone, checkpoint

    except Exception as e:
        # Best effort: any failure just means "download the pretrained backbone".
        print(f"Error analyzing checkpoint: {e}")
        print("Falling back to downloading RegNet weights.")
        return False, None
|
|
class AIImageDetector:
    """Single-image REAL/FAKE classifier built from an AIDetectorModel checkpoint.

    Construction inspects the checkpoint (via analyze_checkpoint) to decide whether
    the ImageNet RegNet weights must be downloaded, loads the trained weights,
    switches the model to eval mode and builds the preprocessing pipeline
    (bicubic resize to 224, center crop 224, ImageNet mean/std normalisation).
    """

    def __init__(self, model_path, device=None):
        """Load *model_path* onto *device* (default: 'cuda' if available, else 'cpu')."""
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Initializing AI Image Detector on {self.device}")

        # The checkpoint loaded during analysis is reused so the file is read once.
        has_backbone, checkpoint = analyze_checkpoint(model_path)

        self.model = self._load_model(model_path, has_backbone, checkpoint)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print("Model loaded and ready for inference.")

    @staticmethod
    def _state_dict_of(checkpoint):
        """Return the state_dict inside a structured checkpoint, or *checkpoint* itself."""
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            return checkpoint['model_state_dict']
        return checkpoint

    def _load_model(self, model_path, has_backbone, checkpoint):
        """Build an AIDetectorModel, load the checkpoint weights and move it to the device.

        Args:
            model_path: file to read when *checkpoint* is None.
            has_backbone: True when the checkpoint carries backbone weights, in which
                case the ImageNet download is skipped.
            checkpoint: already-loaded checkpoint object, or None.

        Raises:
            The original loading error if both the primary and fallback attempts fail,
            or immediately if the checkpoint file could not be read at all.
        """
        try:
            if checkpoint is None:
                checkpoint = torch.load(model_path, map_location=self.device)

            # Only fetch ImageNet weights when the checkpoint cannot supply the backbone.
            load_pretrained = not has_backbone

            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model = AIDetectorModel(
                    num_classes=checkpoint.get('num_classes', 2),
                    dropout_prob=checkpoint.get('dropout_prob', 0.3),
                    load_pretrained=load_pretrained
                )
                model.load_state_dict(checkpoint['model_state_dict'])

                if 'epoch' in checkpoint:
                    print(f"Loaded model from epoch {checkpoint['epoch']}")
                if 'val_loss' in checkpoint:
                    print(f"Validation loss: {checkpoint['val_loss']:.4f}")
            else:
                model = AIDetectorModel(load_pretrained=load_pretrained)
                model.load_state_dict(checkpoint)

            return model.to(self.device)

        except Exception as e:
            print(f"Error loading model: {e}")
            if checkpoint is None:
                # The file itself could not be read; retrying with nothing cannot help.
                raise
            print("Attempting fallback loading...")

            try:
                model = AIDetectorModel(load_pretrained=True)
                model.load_state_dict(self._state_dict_of(checkpoint))
                return model.to(self.device)
            except Exception as fallback_error:
                print(f"Fallback loading also failed: {fallback_error}")
                raise e

    def preprocess_image(self, image):
        """Convert a PIL image to RGB and return a normalised (1, C, H, W) tensor on the device."""
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            tensor = self.transform(image).unsqueeze(0)
            return tensor.to(self.device)
        except Exception as e:
            print(f"Error preprocessing image: {e}")
            raise

    def predict(self, image):
        """
        Predict if image is real or AI-generated

        Returns:
            tuple: (prediction, confidence, probabilities) where prediction is
            'REAL' (class 0) or 'FAKE' (class 1), confidence is the winning softmax
            probability and probabilities is the per-class numpy array.
        """
        try:
            input_tensor = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)
                confidence, predicted = torch.max(probabilities, 1)

            probs = probabilities.cpu().numpy()[0]
            pred_class = predicted.cpu().item()
            conf_score = confidence.cpu().item()

            # Index order matches the training labels: 0 -> REAL, 1 -> FAKE.
            class_names = ['REAL', 'FAKE']
            prediction = class_names[pred_class]

            return prediction, conf_score, probs

        except Exception as e:
            print(f"Error during prediction: {e}")
            raise

    def predict_from_url(self, url):
        """Download an image from *url* and classify it.

        Returns:
            (prediction, confidence, probabilities, PIL image)

        Raises:
            Exception: "Failed to download image from URL: ..." on network/HTTP errors
                (the requests error is chained as ``__cause__``); other errors propagate.

        NOTE(review): *url* is user-supplied and fetched server-side without a size
        cap or host allow-list (SSRF / memory risk).
        """
        try:
            headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
            response = requests.get(url, timeout=15, headers=headers)
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))
            prediction, confidence, probabilities = self.predict(image)

            return prediction, confidence, probabilities, image

        except requests.exceptions.RequestException as e:
            print(f"Error downloading image: {e}")
            raise Exception(f"Failed to download image from URL: {str(e)}") from e
        except Exception as e:
            print(f"Error processing image from URL: {e}")
            raise
|
|
def get_available_models():
    """List checkpoint files (``*.pth`` / ``*.pt``) located next to this module.

    Returns a sorted list of bare file names (not paths). Only regular files are
    returned, so a directory whose name happens to end in a checkpoint extension is
    not offered as a model. When nothing is found, or the directory cannot be read,
    the two conventional default names are returned so the UI still has options.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    exts = (".pth", ".pt")
    try:
        discovered = sorted(
            name
            for name in os.listdir(base_dir)
            if name.lower().endswith(exts) and os.path.isfile(os.path.join(base_dir, name))
        )
    except OSError:
        discovered = []

    if discovered:
        print(f"Found {len(discovered)} model file(s): {discovered}")
        return discovered

    print("No model files found alongside app.py, showing default options")
    return ["best_ai_detector.pth", "best_ai_detector_new.pth"]
|
|
def load_detector(model_path=None):
    """Load (or reuse) the global detector for *model_path* and return a status message.

    A relative *model_path* is first looked up next to this module; if no such file
    exists there it is left as-is (i.e. relative to the process working directory).
    If the resolved path is already loaded, nothing is reloaded. This function never
    raises: failures are reported through the returned message.

    NOTE(review): *model_path* can come straight from a form field and is passed to
    torch.load, so any readable file on the server can be loaded; consider restricting
    it to get_available_models().
    """
    global detector, current_model_path

    try:
        if model_path is None:
            model_path = 'best_ai_detector.pth'

        base_dir = os.path.dirname(os.path.abspath(__file__))
        resolved_path = model_path
        if not os.path.isabs(resolved_path):
            candidate = os.path.join(base_dir, resolved_path)
            # isfile, not exists: a directory must never be treated as a checkpoint.
            if os.path.isfile(candidate):
                resolved_path = candidate

        if detector is not None and current_model_path == resolved_path:
            return f"Model '{model_path}' is already loaded and ready."

        if not os.path.isfile(resolved_path):
            error_msg = f"Model file '{model_path}' not found.\n\n"
            error_msg += "Available models in current directory:\n"
            for model in get_available_models():
                # Names come from base_dir, so check them there (not in the CWD).
                exists = "OK" if os.path.isfile(os.path.join(base_dir, model)) else "MISSING"
                error_msg += f" {exists} {model}\n"
            error_msg += "\nPlease ensure your trained model file is uploaded to the Space."
            return error_msg

        print(f"Loading model: {resolved_path}")

        detector = AIImageDetector(resolved_path)
        current_model_path = resolved_path

        return f"Model '{model_path}' loaded successfully.\nReady for image analysis."

    except Exception as e:
        error_msg = f"Error loading model '{model_path}': {str(e)}\n\n"
        error_msg += "This might be due to:\n"
        error_msg += "• Incompatible model file\n"
        error_msg += "• Corrupted checkpoint\n"
        error_msg += "• Missing dependencies\n"
        return error_msg
|
|
| def _now_iso() -> str: |
| return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC") |
|
|
|
|
| def _record_history(entry: Dict[str, Any]) -> None: |
| try: |
| history.appendleft(entry) |
| except Exception: |
| pass |
|
|
|
|
| def _get_selected_model_path(model_path: Optional[str] = None) -> str: |
| |
| if model_path and str(model_path).strip(): |
| return str(model_path).strip() |
| if current_model_path and str(current_model_path).strip(): |
| return str(current_model_path).strip() |
| return "best_ai_detector.pth" |
|
|
|
|
def ensure_model_loaded(model_path: Optional[str] = None) -> str:
    """Make sure a detector is loaded and return the loader's status message.

    A blank *model_path* falls back to the currently loaded checkpoint, then to
    'best_ai_detector.pth' (see _get_selected_model_path).
    """
    chosen = _get_selected_model_path(model_path)
    return _ensure_detector_loaded(chosen)
|
|
def predict_image(image):
    """Legacy UI handler: classify an uploaded PIL image with the global detector.

    Returns:
        (markdown, probability DataFrame, label, status). When no model is loaded, no
        image is given, or prediction fails, the first element is a message and the
        remaining three are None.
    """
    if detector is None:
        return "Please load a model first using the dropdown and 'Load model' button.", None, None, None

    if image is None:
        return "Please upload an image first.", None, None, None

    try:
        prediction, confidence, probabilities = detector.predict(image)

        # Shared formatters keep this handler's output identical to the FastAPI views.
        result_text = _format_prediction_markdown(prediction, confidence, probabilities)
        prob_df = pd.DataFrame({
            'Category': ['REAL (Human)', 'FAKE (AI)'],
            'Probability': [float(probabilities[0]), float(probabilities[1])]
        })
        status = _quick_status(prediction, confidence)

        return result_text, prob_df, prediction, status

    except Exception as e:
        error_msg = f"Prediction failed: {str(e)}"
        print(error_msg)
        return error_msg, None, None, None
|
|
def predict_from_url(url):
    """Legacy UI handler: download an image from *url* and classify it.

    Returns:
        (markdown, probability DataFrame, label, status, PIL image). When no model is
        loaded, the URL is blank, or processing fails, the first element is a message
        and the remaining four are None.
    """
    if detector is None:
        return "Please load a model first using the dropdown and 'Load model' button.", None, None, None, None

    if not url or not url.strip():
        return "Please enter a valid image URL.", None, None, None, None

    try:
        prediction, confidence, probabilities, image = detector.predict_from_url(url.strip())

        # Shared formatters keep this handler's output identical to the FastAPI views.
        result_text = _format_prediction_markdown(prediction, confidence, probabilities)
        prob_df = pd.DataFrame({
            'Category': ['REAL (Human)', 'FAKE (AI)'],
            'Probability': [float(probabilities[0]), float(probabilities[1])]
        })
        status = _quick_status(prediction, confidence)

        return result_text, prob_df, prediction, status, image

    except Exception as e:
        error_msg = f"URL processing failed: {str(e)}"
        print(error_msg)
        return error_msg, None, None, None, None
|
|
| def _format_prediction_markdown(prediction: str, confidence: float, probabilities: np.ndarray) -> str: |
| result_text = f"**Prediction:** {prediction}\n\n" |
| result_text += f"**Confidence:** {confidence:.1%}\n\n" |
| result_text += "**Detailed probabilities:**\n" |
| result_text += f"- REAL (human-made): {probabilities[0]:.1%}\n" |
| result_text += f"- FAKE (AI-generated): {probabilities[1]:.1%}\n\n" |
|
|
| if confidence > 0.8: |
| result_text += "**High confidence prediction**" |
| elif confidence > 0.6: |
| result_text += "**Moderate confidence prediction**" |
| else: |
| result_text += "**Low confidence - uncertain prediction**" |
| return result_text |
|
|
|
|
| def _quick_status(prediction: str, confidence: float) -> str: |
| return f"{prediction} - {confidence:.1%} confidence" |
|
|
|
|
def _ensure_detector_loaded(model_path: Optional[str]) -> str:
    """Thin alias for load_detector(); returns its status message unchanged."""
    return load_detector(model_path)
|
|
|
|
def _analyze_pil_image(image: Image.Image) -> Dict[str, Any]:
    """Classify *image* with the global detector and package the result for templates.

    Returns a dict with keys: prediction, confidence, probabilities ([p_real, p_fake]),
    markdown, quick, and model (basename of the loaded checkpoint, or None).

    Raises:
        RuntimeError: if no detector has been loaded.
    """
    if detector is None:
        raise RuntimeError("Detector is not loaded. Load a model first.")

    label, conf, probs = detector.predict(image)
    conf_value = float(conf)
    model_name = os.path.basename(current_model_path) if current_model_path else None
    return {
        "prediction": label,
        "confidence": conf_value,
        "probabilities": [float(probs[0]), float(probs[1])],
        "markdown": _format_prediction_markdown(label, conf_value, probs),
        "quick": _quick_status(label, conf_value),
        "model": model_name,
    }
|
|
|
|
def _download_to_tempfile(url: str, suffix: str, max_bytes: int = 200 * 1024 * 1024) -> str:
    """Stream *url* into a fresh temporary file and return that file's path.

    The caller owns (and must delete) the returned file. The partial file is removed
    if anything fails mid-download.

    Raises:
        requests.HTTPError: non-2xx response (from raise_for_status).
        RuntimeError: the body exceeds *max_bytes*.

    NOTE(review): *url* is user-supplied and fetched server-side (SSRF risk).
    """
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
    with requests.get(url, timeout=30, headers=headers, stream=True) as response:
        response.raise_for_status()
        handle, tmp_path = tempfile.mkstemp(suffix=suffix)
        os.close(handle)
        written = 0
        try:
            with open(tmp_path, "wb") as sink:
                for piece in response.iter_content(chunk_size=1024 * 1024):
                    if not piece:
                        continue  # keep-alive chunks carry no data
                    written += len(piece)
                    if written > max_bytes:
                        raise RuntimeError(f"File too large (>{max_bytes} bytes).")
                    sink.write(piece)
        except Exception:
            try:
                os.remove(tmp_path)
            except Exception:
                pass
            raise
    return tmp_path
|
|
|
|
| def _looks_like_youtube(url: str) -> bool: |
| u = (url or "").lower() |
| return ("youtube.com" in u) or ("youtu.be" in u) |
|
|
|
|
def _clear_directory(path: str) -> None:
    """Best-effort removal of every entry inside *path* (the directory itself is kept)."""
    for name in os.listdir(path):
        entry = os.path.join(path, name)
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry, ignore_errors=True)
        else:
            try:
                os.remove(entry)
            except OSError:
                pass


def _ytdlp_download_once(url: str, outtmpl: str, out_dir: str, fmt: str,
                         client_list: Optional[List[str]]) -> str:
    """Run a single yt-dlp download attempt and return the path of the file it produced.

    Raises whatever yt-dlp raises, or RuntimeError when no output file can be found.
    """
    ydl_opts: Dict[str, Any] = {
        "outtmpl": outtmpl,
        "format": fmt,
        "merge_output_format": "mp4",
        "quiet": True,
        "noprogress": True,
        "retries": 3,
        "socket_timeout": 30,
    }
    if client_list:
        ydl_opts["extractor_args"] = {"youtube": {"player_client": client_list}}

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)

    # Prefer the paths yt-dlp itself reports for this download...
    candidates: List[str] = []
    if isinstance(info, dict):
        req = info.get("requested_downloads")
        if isinstance(req, list):
            for item in req:
                if isinstance(item, dict) and item.get("filepath"):
                    candidates.append(str(item["filepath"]))
        if info.get("_filename"):
            candidates.append(str(info["_filename"]))

    # ...then whatever landed in the scratch dir, largest first (typically the merged file).
    produced = [os.path.join(out_dir, name) for name in os.listdir(out_dir)]
    produced = [p for p in produced if os.path.isfile(p)]
    produced.sort(key=os.path.getsize, reverse=True)
    candidates.extend(produced)

    out_path = next((p for p in candidates if p and os.path.exists(p)), None)
    if not out_path:
        raise RuntimeError("YouTube download failed (no output file produced).")
    return out_path


def _download_youtube_to_tempfile(url: str, max_bytes: int = 500 * 1024 * 1024) -> str:
    """Download a YouTube video with yt-dlp into a standalone temporary file.

    Tries several player clients x format selectors (merged video+audio formats only
    when ffmpeg is on PATH) until one attempt yields a file. That file is copied out
    of the scratch directory into its own mkstemp() file, which the caller must delete;
    the scratch directory is always removed.

    Raises:
        RuntimeError: the video exceeds *max_bytes* (raised immediately, because every
            other client/format would fetch the same oversized video again), or every
            attempt failed (message includes the last error).
    """
    tmp_dir = tempfile.mkdtemp(prefix="yt_")
    outtmpl = os.path.join(tmp_dir, "video.%(ext)s")
    has_ffmpeg = shutil.which("ffmpeg") is not None

    formats_to_try = ["best[ext=mp4]/best"]
    if has_ffmpeg:
        formats_to_try.extend(["bv*+ba/b", "bestvideo+bestaudio/best", "best"])
    youtube_clients_to_try = [None, ["android"], ["tv"], ["web"], ["android", "web"]]

    try:
        last_err: Optional[Exception] = None
        for client_list in youtube_clients_to_try:
            for fmt in formats_to_try:
                # Start every attempt from an empty directory so leftovers (e.g. .part
                # files) from a failed attempt cannot be mistaken for this one's output.
                _clear_directory(tmp_dir)
                try:
                    out_path = _ytdlp_download_once(url, outtmpl, tmp_dir, fmt, client_list)
                except Exception as e:
                    last_err = e
                    continue

                if os.path.getsize(out_path) > max_bytes:
                    raise RuntimeError(f"Downloaded video too large (>{max_bytes} bytes).")

                fd, final_path = tempfile.mkstemp(suffix=os.path.splitext(out_path)[1] or ".mp4")
                os.close(fd)
                try:
                    # Streamed copy; reading the whole video into memory is avoided.
                    shutil.copyfile(out_path, final_path)
                except Exception:
                    try:
                        os.remove(final_path)
                    except OSError:
                        pass
                    raise
                return final_path

        if not has_ffmpeg:
            raise RuntimeError(
                "YouTube download failed. This environment does not have ffmpeg installed, "
                "so only single-file MP4 formats can be used. Try a different video URL or install ffmpeg. "
                f"Last error: {last_err}"
            )
        raise RuntimeError(f"YouTube download failed for all format fallbacks: {last_err}")
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
|
|
|
|
def _pil_to_data_uri(img: Image.Image, max_side: int = 720, quality: int = 85) -> str:
    """Encode *img* as a base64 JPEG ``data:`` URI for inline display in the page.

    The image is converted to RGB (JPEG has no alpha) and, when its longer side
    exceeds *max_side*, downscaled with bicubic resampling keeping the aspect ratio.
    Each scaled side is clamped to at least 1 px so extreme aspect ratios
    (e.g. 10000x3) do not request a zero-sized resize.
    """
    im = img.convert("RGB")
    w, h = im.size
    scale = min(1.0, float(max_side) / float(max(w, h)))
    if scale < 1.0:
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        im = im.resize(new_size, Image.Resampling.BICUBIC)
    buf = BytesIO()
    im.save(buf, format="JPEG", quality=quality, optimize=True)
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{b64}"
|
|
|
|
def _sample_random_frames(video_path: str, num_frames: int, seed: Optional[int] = None) -> List[Tuple[int, Image.Image]]:
    """Decode up to *num_frames* randomly chosen frames from *video_path*.

    Returns (frame_index, RGB PIL image) pairs sorted by frame index. When the
    container reports a frame count, that many *distinct* indices are sampled without
    replacement and seeked to directly. If no frame count is reported, or every seek
    fails (the reported count was wrong), the stream is decoded once sequentially and
    reservoir sampling keeps a uniform subset. *seed* makes the choice reproducible.

    Raises:
        RuntimeError: if OpenCV cannot open the file.
    """
    rng = random.Random(seed)  # Random(None) seeds from system entropy

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError("Failed to open video. Unsupported codec/container or corrupted file.")

    def _to_pil(frame) -> Image.Image:
        # OpenCV decodes to BGR; the model expects RGB.
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    try:
        if num_frames <= 0:
            return []

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)

        if frame_count > 0:
            # Sample without replacement so callers really get num_frames frames when
            # the video is long enough (randint + dedupe silently returned fewer).
            idxs = sorted(rng.sample(range(frame_count), min(num_frames, frame_count)))
            frames: List[Tuple[int, Image.Image]] = []
            for idx in idxs:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ok, frame = cap.read()
                if ok and frame is not None:
                    frames.append((idx, _to_pil(frame)))
            if frames:
                return frames
            # Every seek failed: the reported count is unreliable. Rewind and scan.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

        # Reservoir sampling (Algorithm R) over a single sequential pass.
        reservoir: List[Tuple[int, Image.Image]] = []
        position = 0
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            if position < num_frames:
                reservoir.append((position, _to_pil(frame)))
            else:
                slot = rng.randint(0, position)
                if slot < num_frames:
                    reservoir[slot] = (position, _to_pil(frame))
            position += 1
        return sorted(reservoir, key=lambda item: item[0])
    finally:
        cap.release()
|
|
|
|
def analyze_video_from_url(url: str, num_frames: int = 5, seed: Optional[int] = None) -> Dict[str, Any]:
    """Download a video (direct link or YouTube) and score randomly sampled frames.

    Frame scoring and the Markdown report are delegated to analyze_video_from_path
    so the URL and upload entry points cannot drift apart. The downloaded temporary
    file is always deleted.

    Returns the same dict as analyze_video_from_path, with ``is_youtube`` reflecting
    the URL and without the upload-specific ``video_source_type`` key (the URL route
    sets its own).

    Raises:
        RuntimeError: blank URL, download failure, or no decodable frames.
    """
    if not url or not url.strip():
        raise RuntimeError("Please provide a valid video URL.")

    clean_url = url.strip()
    is_youtube = _looks_like_youtube(clean_url)
    if is_youtube:
        tmp_path = _download_youtube_to_tempfile(clean_url)
    else:
        tmp_path = _download_to_tempfile(clean_url, suffix=".mp4")

    try:
        result = analyze_video_from_path(tmp_path, num_frames=num_frames, seed=seed)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass

    result["is_youtube"] = is_youtube
    result.pop("video_source_type", None)
    return result
|
|
|
|
def analyze_video_from_path(video_path: str, num_frames: int = 5, seed: Optional[int] = None) -> Dict[str, Any]:
    """Score randomly sampled frames of a local video file and summarise them.

    Each sampled frame is classified with the global detector. The overall verdict
    is FAKE when the mean p_fake across frames is >= 0.5, otherwise REAL; a
    per-frame majority vote is reported alongside it for information.

    Returns a dict with keys: overall, avg_p_fake, frames (per-frame dicts incl. a
    data-URI thumbnail), report_markdown, is_youtube (False), model,
    video_source_type ("upload").

    Raises:
        RuntimeError: if no frame could be decoded (or the file cannot be opened).
    """
    sampled = _sample_random_frames(video_path, num_frames=max(1, int(num_frames)), seed=seed)
    if not sampled:
        raise RuntimeError("Could not decode any frames from the video.")

    frames_out: List[Dict[str, Any]] = []
    for frame_idx, pil_img in sampled:
        scored = _analyze_pil_image(pil_img)
        p_real, p_fake = scored["probabilities"]
        frames_out.append({
            "frame_index": int(frame_idx),
            "prediction": scored["prediction"],
            "confidence": scored["confidence"],
            "p_real": p_real,
            "p_fake": p_fake,
            "quick": scored["quick"],
            "image_data_uri": _pil_to_data_uri(pil_img),
        })

    avg_fake = float(np.mean([row["p_fake"] for row in frames_out]))
    fake_votes = sum(row["prediction"] == "FAKE" for row in frames_out)
    real_votes = len(frames_out) - fake_votes
    overall = "FAKE" if avg_fake >= 0.5 else "REAL"

    report_lines = [
        "## Video Analysis Report\n\n",
        f"**Frames analyzed:** {len(frames_out)}\n\n",
        f"**Overall (avg FAKE prob (p_fake)):** {overall} (avg p_fake = {avg_fake:.1%})\n\n",
        f"**Vote summary:** FAKE {fake_votes} / REAL {real_votes}\n\n",
        "### Frame breakdown\n",
    ]
    for row in sorted(frames_out, key=lambda r: r["frame_index"]):
        report_lines.append(
            f"- Frame `{row['frame_index']}`: **{row['prediction']}** "
            f"(conf {row['confidence']:.1%}, p_fake {row['p_fake']:.1%})\n"
        )

    return {
        "overall": overall,
        "avg_p_fake": avg_fake,
        "frames": frames_out,
        "report_markdown": "".join(report_lines),
        "is_youtube": False,
        "model": os.path.basename(current_model_path) if current_model_path else None,
        "video_source_type": "upload",
    }
|
|
|
|
# Everything (templates, static assets, model files) is resolved relative to this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

app = FastAPI(title="AI Image Detector", version="1.0")
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))

# Static assets are optional: mount them only when the directory is present.
static_dir = os.path.join(BASE_DIR, "static")
if os.path.isdir(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render the landing page: model picker, current status and recent history, no results."""
    page_context = dict(
        request=request,
        available_models=get_available_models(),
        current_model_path=current_model_path,
        model_status="Select a model and click 'Load Model' to initialize the AI detector",
        image_result=None,
        url_result=None,
        video_result=None,
        history=list(history),
    )
    return templates.TemplateResponse("index.html", page_context)
|
|
|
|
@app.post("/load_model", response_class=HTMLResponse)
def load_model(request: Request, model_path: str = Form(...)):
    """Load the model selected in the form, then re-render the page with its status."""
    load_status = ensure_model_loaded(model_path)
    page_context = dict(
        request=request,
        available_models=get_available_models(),
        current_model_path=current_model_path,
        model_status=load_status,
        image_result=None,
        url_result=None,
        video_result=None,
        history=list(history),
    )
    return templates.TemplateResponse("index.html", page_context)
|
|
|
|
@app.post("/analyze_image_upload", response_class=HTMLResponse)
async def analyze_image_upload(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded image and re-render the page with the result.

    Model loading, image decoding and inference are synchronous and CPU/GPU-heavy.
    Calling them directly inside this ``async`` handler would stall the event loop
    (and every other request) for their whole duration, so they are dispatched to
    the thread pool. Any failure is rendered as ``image_result["error"]``.
    """
    from fastapi.concurrency import run_in_threadpool

    def _analyze_bytes(data: bytes) -> Dict[str, Any]:
        # Blocking work: decode, classify and thumbnail the upload.
        image = Image.open(BytesIO(data))
        result = _analyze_pil_image(image)
        result["image_data_uri"] = _pil_to_data_uri(image)
        return result

    status = await run_in_threadpool(ensure_model_loaded)
    try:
        data = await file.read()
        image_result = await run_in_threadpool(_analyze_bytes, data)
        image_result["filename"] = file.filename
        _record_history(
            {
                "ts": _now_iso(),
                "type": "image_upload",
                "source": file.filename,
                "model": image_result.get("model"),
                "prediction": image_result.get("prediction"),
                "confidence": image_result.get("confidence"),
            }
        )
    except Exception as e:
        image_result = {"error": str(e)}

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "available_models": get_available_models(),
            "current_model_path": current_model_path,
            "model_status": status,
            "image_result": image_result,
            "url_result": None,
            "video_result": None,
            "history": list(history),
        },
    )
|
|
|
|
@app.post("/analyze_image_url", response_class=HTMLResponse)
def analyze_image_url(request: Request, url: str = Form(...)):
    """Download an image from a user-supplied URL, classify it and re-render the page.

    A blank URL or a model that failed to load is reported as a readable error in
    ``url_result["error"]`` (previously a missing detector surfaced as
    "'NoneType' object has no attribute 'predict_from_url'").
    """
    status = ensure_model_loaded()
    clean_url = url.strip()
    try:
        if not clean_url:
            raise RuntimeError("Please enter a valid image URL.")
        if detector is None:
            raise RuntimeError("Detector is not loaded. Load a model first.")

        prediction, confidence, probabilities, img = detector.predict_from_url(clean_url)
        conf = float(confidence)
        url_result = {
            "prediction": prediction,
            "confidence": conf,
            "probabilities": [float(probabilities[0]), float(probabilities[1])],
            "markdown": _format_prediction_markdown(prediction, conf, probabilities),
            "quick": _quick_status(prediction, conf),
            "url": clean_url,
            "image_data_uri": _pil_to_data_uri(img),
            "model": os.path.basename(current_model_path) if current_model_path else None,
        }
        _record_history(
            {
                "ts": _now_iso(),
                "type": "image_url",
                "source": clean_url,
                "model": url_result.get("model"),
                "prediction": url_result.get("prediction"),
                "confidence": url_result.get("confidence"),
            }
        )
    except Exception as e:
        url_result = {"error": str(e), "url": clean_url}

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "available_models": get_available_models(),
            "current_model_path": current_model_path,
            "model_status": status,
            "image_result": None,
            "url_result": url_result,
            "video_result": None,
            "history": list(history),
        },
    )
|
|
|
|
@app.post("/analyze_video_url", response_class=HTMLResponse)
def analyze_video_url(request: Request, video_url: str = Form(...), num_frames: int = Form(5)):
    """Score sampled frames of a remote/YouTube video and re-render the page.

    *num_frames* is clamped to [1, 50]. The history entry's confidence is the
    probability of the overall verdict (1 - avg p_fake for REAL, avg p_fake for FAKE).
    """
    status = ensure_model_loaded()
    source = video_url.strip()
    try:
        frame_budget = min(50, max(1, int(num_frames)))
        video_result = analyze_video_from_url(source, num_frames=frame_budget)
        video_result["video_url"] = source
        video_result["video_source_type"] = "url"

        verdict = video_result.get("overall")
        avg_fake = float(video_result.get("avg_p_fake", 0.0))
        _record_history(
            {
                "ts": _now_iso(),
                "type": "video_url",
                "source": source,
                "model": video_result.get("model"),
                "prediction": verdict,
                "confidence": (1.0 - avg_fake) if verdict == "REAL" else avg_fake,
            }
        )
    except Exception as exc:
        video_result = {"error": str(exc), "video_url": source}

    page_context = dict(
        request=request,
        available_models=get_available_models(),
        current_model_path=current_model_path,
        model_status=status,
        image_result=None,
        url_result=None,
        video_result=video_result,
        history=list(history),
    )
    return templates.TemplateResponse("index.html", page_context)
|
|
|
|
@app.post("/analyze_video_upload", response_class=HTMLResponse)
async def analyze_video_upload(request: Request, file: UploadFile = File(...), num_frames: int = Form(5)):
    """Score randomly sampled frames of an uploaded video and re-render the page.

    The upload is spooled to a temporary file (always deleted afterwards) because
    OpenCV needs a filesystem path. Model loading and frame decoding/inference are
    synchronous and heavy, so they run in the thread pool instead of blocking the
    event loop of this ``async`` handler. *num_frames* is clamped to [1, 50].
    """
    from fastapi.concurrency import run_in_threadpool

    status = await run_in_threadpool(ensure_model_loaded)
    tmp_path = None
    try:
        n = min(50, max(1, int(num_frames)))

        suffix = os.path.splitext(file.filename or "")[1] or ".mp4"
        fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        with open(tmp_path, "wb") as f:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                f.write(chunk)

        video_result = await run_in_threadpool(analyze_video_from_path, tmp_path, num_frames=n)
        video_result["video_filename"] = file.filename

        overall = video_result.get("overall")
        avg_fake = float(video_result.get("avg_p_fake", 0.0))
        _record_history(
            {
                "ts": _now_iso(),
                "type": "video_upload",
                "source": file.filename,
                "model": video_result.get("model"),
                "prediction": overall,
                # Confidence of the overall verdict, not of FAKE specifically.
                "confidence": (1.0 - avg_fake) if overall == "REAL" else avg_fake,
            }
        )
    except Exception as e:
        video_result = {"error": str(e), "video_filename": getattr(file, "filename", None), "video_source_type": "upload"}
    finally:
        if tmp_path:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "available_models": get_available_models(),
            "current_model_path": current_model_path,
            "model_status": status,
            "image_result": None,
            "url_result": None,
            "video_result": video_result,
            "history": list(history),
        },
    )
|
|
if __name__ == "__main__":
    import uvicorn

    print("Starting FastAPI AI Image Detector...")
    print(f"Available models: {get_available_models()}")
    # Pass the app object instead of the "app:app" import string: the string form only
    # works when this file is named app.py, and makes uvicorn import the module a
    # second time (a separate copy of all module-level state).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")), reload=False)