Spaces:
Paused
Paused
| import time | |
| import requests | |
| import torch | |
| import torchvision.transforms as T | |
| from torchvision.models import resnet50, ResNet50_Weights | |
| from PIL import Image | |
| from io import BytesIO | |
| import os | |
class ImageAnalyzer:
    """Classify images with a pre-trained torchvision ResNet-50.

    ``analyze`` accepts either a URL string (fetched over HTTP) or an image
    array (the original caller comments describe a numpy array coming from
    Gradio) and returns a Markdown-formatted summary of the top-3
    predictions. Every failure -- missing input, network error, unreadable
    image, inference error -- is reported as a returned message string
    rather than raised.
    """

    # NOTE(review): the leading characters of the user-facing strings in this
    # class ("β", "πΈ", "π―", "π") look like mis-encoded emoji. They are
    # kept byte-for-byte so the output is unchanged; fix the file encoding
    # separately if desired.

    def __init__(self):
        # Load the default pre-trained ResNet-50 weights together with the
        # preprocessing transform and category names bundled with them.
        self.weights = ResNet50_Weights.DEFAULT
        self.model = resnet50(weights=self.weights)
        self.model.eval()  # evaluation mode for inference
        self.preprocess = self.weights.transforms()

    def analyze(self, image_input):
        """Classify *image_input* and return a Markdown report.

        Args:
            image_input: a URL string starting with ``http``, or an image
                array accepted by ``PIL.Image.fromarray``. ``None`` or an
                empty/whitespace-only string means no image was supplied.

        Returns:
            The formatted top-3 prediction report, or an error message
            string if the image could not be loaded or classified.
        """
        img_pil, error = self._load_image(image_input)
        if error is not None:
            return error
        try:
            batch = self.preprocess(img_pil).unsqueeze(0)  # add a batch dimension
            with torch.no_grad():
                prediction = self.model(batch).squeeze(0).softmax(0)
            # topk returns values sorted in descending order, so entry 0 is
            # the overall top prediction -- no separate argmax is needed.
            top_probs, top_ids = torch.topk(prediction, 3)
            return self._format_result(top_probs.tolist(), top_ids.tolist())
        except Exception as e:
            return f"β Inference Failed: {str(e)}"

    @staticmethod
    def _load_image(image_input):
        """Turn *image_input* into an RGB PIL image.

        Returns ``(image, None)`` on success, or ``(None, message)`` where
        *message* is the user-facing error string to show instead.
        """
        if isinstance(image_input, str):
            url = image_input.strip()
            # An empty textbox means "no image", not a broken URL; previously
            # any non-URL string fell through to Image.fromarray and raised.
            if not url:
                return None, "β No valid image provided."
            if not url.startswith("http"):
                return None, "β Error loading URL: only http(s) URLs are supported"
            # Browser-like User-Agent, as in the original implementation.
            headers = {'User-Agent': 'Mozilla/5.0'}
            try:
                # With stream=True requests only releases the connection once
                # the body is consumed or the response is closed; the `with`
                # block closes it on every path, including the non-200 early
                # return that never reads the body.
                with requests.get(url, headers=headers, stream=True, timeout=10) as response:
                    if response.status_code != 200:
                        return None, f"β Failed to fetch image. Status: {response.status_code}"
                    payload = response.content
                return Image.open(BytesIO(payload)).convert('RGB'), None
            except Exception as e:
                return None, f"β Error loading URL: {str(e)}"

        if image_input is None:
            return None, "β No valid image provided."
        try:
            # Uploads arrive as an array (per the caller: numpy from Gradio).
            return Image.fromarray(image_input).convert('RGB'), None
        except Exception as e:
            return None, f"β Error reading image: {str(e)}"

    def _format_result(self, probs, class_ids):
        """Render ranked (probability, class index) pairs as Markdown.

        *probs* and *class_ids* are parallel lists ordered from most to least
        likely; the first entry is reported as the top prediction.
        """
        categories = self.weights.meta["categories"]
        candidates = "\n".join(
            f"{rank}. **{categories[cid]}** ({prob:.1%})"
            for rank, (prob, cid) in enumerate(zip(probs, class_ids), start=1)
        )
        return "\n".join([
            "πΈ **Image Analysis Result**",
            f"π― **Top Prediction**: {categories[class_ids[0]]}",
            f"π **Confidence**: {probs[0]:.1%}",
            "π **Top 3 Candidates**:",
            candidates,
            "*Analysis performed by ResNet50*",
        ])