import time
import requests
import torch
import torchvision.transforms as T
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from io import BytesIO
import os


class ImageAnalyzer:
    """Classify images with a pre-trained torchvision ResNet-50.

    ``analyze`` accepts an http(s) URL, a PIL image, or an image array
    (e.g. the numpy array a Gradio ``Image`` component passes) and returns a
    Markdown-formatted report string. Failures are returned as strings
    prefixed with "❌" rather than raised, so the result can be displayed
    directly in a UI.
    """

    # Some image hosts reject requests that lack a browser-like User-Agent.
    _REQUEST_HEADERS = {'User-Agent': 'Mozilla/5.0'}
    # Seconds to wait for the remote server before giving up.
    _REQUEST_TIMEOUT = 10

    def __init__(self):
        # Load the default pre-trained weights; their metadata also provides
        # the matching preprocessing pipeline and the category labels.
        self.weights = ResNet50_Weights.DEFAULT
        self.model = resnet50(weights=self.weights)
        self.model.eval()  # disable training-only behaviour (e.g. batch-norm updates)
        self.preprocess = self.weights.transforms()

    def _fetch_url(self, url):
        """Download *url* and decode it as an RGB PIL image.

        Returns:
            ``(image, None)`` on success, or ``(None, error_message)``.
        """
        try:
            # The context manager releases the streamed connection even when
            # the body is never read (the non-200 path below).
            with requests.get(url, headers=self._REQUEST_HEADERS, stream=True,
                              timeout=self._REQUEST_TIMEOUT) as response:
                if response.status_code != 200:
                    return None, f"❌ Failed to fetch image. Status: {response.status_code}"
                image = Image.open(BytesIO(response.content)).convert('RGB')
                return image, None
        except Exception as e:
            return None, f"❌ Error loading URL: {str(e)}"

    def _load_image(self, image_input):
        """Normalize any supported input kind to an RGB PIL image.

        Returns:
            ``(image, None)`` on success, or ``(None, error_message)``.
        """
        if isinstance(image_input, str):
            url = image_input.strip()
            if url.startswith("http"):
                return self._fetch_url(url)
            # Blank or non-URL text previously fell through to
            # Image.fromarray and raised; report it as "no image" instead.
            return None, "❌ No valid image provided."
        if image_input is None:
            return None, "❌ No valid image provided."
        if isinstance(image_input, Image.Image):
            return image_input.convert('RGB'), None
        try:
            # Gradio passes uploads as a numpy array; convert it to PIL.
            return Image.fromarray(image_input).convert('RGB'), None
        except Exception as e:
            return None, f"❌ Error loading image: {str(e)}"

    def analyze(self, image_input, top_k=3):
        """Classify *image_input* and return a Markdown report.

        Args:
            image_input: an http(s) URL string, a PIL image, or an image array
                accepted by ``PIL.Image.fromarray`` (e.g. from Gradio).
            top_k: how many ranked candidates to list; clamped to the range
                ``[1, number of categories]``. Defaults to 3.

        Returns:
            The report string, or an error message starting with "❌".
        """
        img_pil, error = self._load_image(image_input)
        if error is not None:
            return error

        try:
            batch = self.preprocess(img_pil).unsqueeze(0)  # add a leading batch dim
            with torch.no_grad():
                probs = self.model(batch).squeeze(0).softmax(0)

            categories = self.weights.meta["categories"]
            k = max(1, min(int(top_k), len(categories)))
            # topk is sorted descending, so entry 0 is also the top-1 result.
            top_probs, top_ids = torch.topk(probs, k)
            ranked = list(zip(top_probs.tolist(), top_ids.tolist()))

            best_prob, best_id = ranked[0]
            candidates = "\n".join(
                f"{rank}. **{categories[idx]}** ({prob:.1%})"
                for rank, (prob, idx) in enumerate(ranked, start=1)
            )
            return "\n".join([
                "📸 **Image Analysis Result**",
                "",
                f"🎯 **Top Prediction**: {categories[best_id]}",
                f"📊 **Confidence**: {best_prob:.1%}",
                "",
                f"🏆 **Top {k} Candidates**:",
                candidates,
                "",
                "*Analysis performed by ResNet50*",
            ])
        except Exception as e:
            return f"❌ Inference Failed: {str(e)}"