# utils.py
import os
import traceback
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import google.generativeai as genai
from dotenv import load_dotenv
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from ultralytics import YOLO

# ✅ Load .env from this backend folder
load_dotenv(Path(__file__).resolve().parent / ".env")

# ✅ Now fetch your Gemini key
GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GEMINI_API_KEY:
    print("❌ GOOGLE_API_KEY not found! Check your .env file path.")
else:
    print("✅ GOOGLE_API_KEY loaded successfully.")
    genai.configure(api_key=GEMINI_API_KEY)

# Default Gemini model name (assumed default; override via the GEMINI_MODEL env var)
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-1.5-flash")

# ========================================
# YOLO Object Detection
# ========================================
yolo_model = YOLO("yolov8n.pt")


def detect_objects(image_path):
    """Detect objects in an image using YOLO."""
    results = yolo_model(image_path)
    detected_objects = []
    for result in results:
        for cls in result.boxes.cls:
            class_idx = int(cls)
            class_name = result.names[class_idx]
            detected_objects.append(class_name)
    # Remove duplicates
    return list(set(detected_objects))


# ========================================
# ViT Model for AI Detection (used in main.py)
# ========================================
MODEL_PATH = os.path.join(os.path.dirname(__file__), "vit-ai-vs-real-model")
processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForImageClassification.from_pretrained(MODEL_PATH)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


# ========================================
# Grad-CAM Visualization
# ========================================
def generate_gradcam(image_path, output_path="gradcam_result.jpg"):
    """
    Generate a Grad-CAM heatmap for the ViT model and save it as an image file.
    """
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Enable gradient computation on the input pixels (used by the fallback path)
    inputs["pixel_values"].requires_grad_(True)

    # Lists to capture attention weights and gradients
    attention_weights = []
    attention_grads = []

    def forward_hook(module, input, output):
        """Capture attention weights during the forward pass."""
        if isinstance(output, tuple) and len(output) > 1:
            attention_weights.append(output[1])
        else:
            attention_weights.append(None)

    def backward_hook(module, grad_input, grad_output):
        """Capture gradients of the attention probabilities during backward."""
        # grad_output mirrors the module's outputs: index 1 is the gradient
        # w.r.t. the returned attention probabilities (when they are present),
        # which is what the Grad-CAM weighting below expects.
        if len(grad_output) > 1 and grad_output[1] is not None:
            attention_grads.append(grad_output[1])

    # Register hooks on the last attention layer
    last_attn = model.vit.encoder.layer[-1].attention.attention
    fwd_handle = last_attn.register_forward_hook(forward_hook)
    bwd_handle = last_attn.register_full_backward_hook(backward_hook)

    try:
        print("🔍 Running forward pass...")
        # Forward pass; output_attentions=True makes the attention modules
        # return their attention maps so the forward hook can capture them.
        outputs = model(**inputs, output_attentions=True)
        pred_class = torch.argmax(outputs.logits, dim=-1).item()
        score = outputs.logits[:, pred_class]
        print(f"🎯 Predicted class: {pred_class}, Score: {score.item():.4f}")

        # Backward pass
        model.zero_grad()
        score.backward()

        print(f"📊 Captured {len(attention_weights)} attention weights")
        print(f"📊 Captured {len(attention_grads)} gradients")

        # Check if we captured valid attention data
        if len(attention_weights) == 0 or attention_weights[0] is None:
            print("⚠️ No attention weights captured - using alternative method")
            # Alternative: use the gradients of the input pixels
            patch_grad = inputs["pixel_values"].grad
            if patch_grad is not None:
                # Average over batch and channels -> [height, width]
                cam = patch_grad.abs().mean(dim=[0, 1]).detach().cpu().numpy()
                # Downsample to the ViT patch grid (14x14 for a 224-pixel
                # input with 16-pixel patches)
                target_size = 14
                cam_tensor = torch.from_numpy(cam).unsqueeze(0).unsqueeze(0)
                cam = F.interpolate(
                    cam_tensor,
                    size=(target_size, target_size),
                    mode="bilinear",
                    align_corners=False,
                )
                cam = cam.squeeze().numpy()
            else:
                print("⚠️ No gradients found - using center-focused fallback")
                # Create a simple center-focused heatmap as a fallback
                cam = np.ones((14, 14))
                center = 7
                for i in range(14):
                    for j in range(14):
                        dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
                        cam[i, j] = max(0, 1 - dist / 10)
        else:
            print("✅ Processing attention weights...")
            # Get attention weights and gradients
            attn = attention_weights[0]  # [batch, heads, seq_len, seq_len]
            if len(attention_grads) > 0:
                grad = attention_grads[0]  # [batch, heads, seq_len, seq_len]
                # Average over heads
                weights = grad.mean(dim=1).squeeze(0)  # [seq_len, seq_len]
                attn_map = attn.mean(dim=1).squeeze(0)  # [seq_len, seq_len]
                # Weight the attention map by its gradients
                cam = (weights * attn_map).sum(dim=0).detach().cpu().numpy()
            else:
                # Just use attention without gradients
                attn_map = attn.mean(dim=1).squeeze(0)  # [seq_len, seq_len]
                cam = attn_map.mean(dim=0).detach().cpu().numpy()

            # Remove the CLS token (first position) and reshape to spatial dims
            cam = cam[1:]
            size = int(np.sqrt(len(cam)))
            cam = cam[: size * size].reshape(size, size)

        # Normalize
        cam = np.maximum(cam, 0)
        if cam.max() > 0:
            cam = cam / cam.max()
        else:
            cam = np.ones_like(cam) * 0.5

        print(f"📏 CAM shape: {cam.shape}, min: {cam.min():.4f}, max: {cam.max():.4f}")

    except Exception as e:
        print(f"❌ Error during Grad-CAM computation: {e}")
        traceback.print_exc()
        # Fallback: create a flat heatmap
        print("🔄 Using fallback heatmap")
        cam = np.ones((14, 14)) * 0.5

    finally:
        # Remove hooks
        fwd_handle.remove()
        bwd_handle.remove()

    # Resize CAM to match the input image size (cast to float32 for OpenCV)
    cam = cv2.resize(cam.astype(np.float32), (image.size[0], image.size[1]))

    # Create heatmap
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay on the original image
    overlay = np.array(image) * 0.6 + heatmap * 0.4
    overlay = np.uint8(overlay)

    # Save the Grad-CAM image
    cv2.imwrite(output_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

    print("🟢 Grad-CAM generated!")
    print(f"💾 Saved at: {os.path.abspath(output_path)}")
    return output_path


def generate_gemini_summary(
    original_image_path: str = None,
    grad_cam_image_path: str = None,
    classification: str = None,
    probability_percent: float = None,
    objects: list = None,
) -> str:
    """
    Unified Gemini summary generator for both vision and text-only models.
    Works for Gemini 1.5 (vision) and text-only Gemini models.
    """
    try:
        if not GEMINI_MODEL:
            raise ValueError("No Gemini model available")

        print(f"🤖 Using Gemini model: {GEMINI_MODEL}")
        gemini_model_instance = genai.GenerativeModel(GEMINI_MODEL)

        # Detect whether the model supports vision
        is_vision_model = any(
            x in GEMINI_MODEL.lower() for x in ["1.5", "2.0", "2.5", "vision"]
        )

        # Vision model: send both images
        if original_image_path and grad_cam_image_path and is_vision_model:
            with open(original_image_path, "rb") as f1, open(grad_cam_image_path, "rb") as f2:
                img_original = {"mime_type": "image/jpeg", "data": f1.read()}
                img_gradcam = {"mime_type": "image/jpeg", "data": f2.read()}

            prompt = f"""
You are an expert in AI image forensics and explainable AI (XAI).
Two images are provided:
1. The original image.
2. The Grad-CAM heatmap (red/yellow = strong focus regions).

The detection model's result:
• Classification: {classification or "Unknown"}
• Probability: {probability_percent or 0:.2f}%

Write a concise 2–3 sentence explanation describing:
- Which regions were highlighted by Grad-CAM.
- Why those regions indicate the model's decision.
- How the confidence reflects these cues.
"""
            response = gemini_model_instance.generate_content(
                [prompt, img_original, img_gradcam]
            )
            return response.text.strip()

        # Text-only fallback
        else:
            prompt = f"""
The model detected these objects: {', '.join(objects or [])}.
Classification result: {classification or "Unknown"} ({probability_percent or 0:.2f}% confidence).
Write a short 2–3 sentence explanation of why the image might appear {classification.lower() if classification else 'AI-generated or real'}.
"""
            response = gemini_model_instance.generate_content(prompt)
            return response.text.strip()

    except Exception as e:
        print(f"⚠️ Gemini summary error: {e}")
        err = str(e)
        if "429" in err or "quota" in err.lower():
            print("💡 Quota exceeded; fallback summary used.")
            if not objects:
                return "This image likely contains a few simple visual features, but detailed analysis is unavailable."
            elif len(objects) == 1:
                return f"This image features a {objects[0]}."
            elif len(objects) == 2:
                return f"This image shows a {objects[0]} and a {objects[1]}."
            else:
                return f"This image contains {', '.join(objects[:-1])}, and {objects[-1]}."
        else:
            return f"Unable to generate Gemini summary. Possible reason: {err}"
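

# ========================================
# Example usage (minimal sketch)
# ========================================
# A hedged demo of how the helpers above compose into one pipeline.
# "sample.jpg", the classification label, and the confidence value below are
# placeholders for illustration only; the real pipeline is driven from main.py.
if __name__ == "__main__":
    sample_image = "sample.jpg"  # hypothetical test image path
    if os.path.exists(sample_image):
        objects = detect_objects(sample_image)
        print(f"Detected objects: {objects}")

        cam_path = generate_gradcam(sample_image, output_path="gradcam_result.jpg")

        summary = generate_gemini_summary(
            original_image_path=sample_image,
            grad_cam_image_path=cam_path,
            classification="AI-generated",  # placeholder label
            probability_percent=87.5,       # placeholder confidence
            objects=objects,
        )
        print(summary)
    else:
        print("Place a test image at 'sample.jpg' to run this demo.")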