File size: 2,847 Bytes
cd42f59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import time
import requests
import torch
import torchvision.transforms as T
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from io import BytesIO
import os

class ImageAnalyzer:
    """Classify an image with a pre-trained ResNet50 and describe the result.

    The weights, the model (switched to eval mode) and the preprocessing
    transform that belongs to those weights are loaded once in ``__init__``
    and reused by every ``analyze`` call.
    """

    # Number of ranked candidates listed in the report.
    TOP_K = 3
    # Seconds to wait for a remote image before giving up.
    REQUEST_TIMEOUT = 10
    # Message returned when the input cannot be interpreted as an image.
    _NO_IMAGE_MESSAGE = "❌ No valid image provided."

    def __init__(self):
        """Load the default ResNet50 weights and prepare the model for inference."""
        self.weights = ResNet50_Weights.DEFAULT
        self.model = resnet50(weights=self.weights)
        self.model.eval()
        self.preprocess = self.weights.transforms()

    def analyze(self, image_input) -> str:
        """Classify *image_input* and return a Markdown-formatted report.

        Args:
            image_input: Either a string starting with ``"http"`` (fetched as
                an image URL), or array-like image data accepted by
                ``PIL.Image.fromarray`` (the original comments indicate a
                numpy array handed over by Gradio), or ``None``.

        Returns:
            A report naming the top prediction, its confidence and the ranked
            candidates. Every failure (missing input, download error, bad
            image data, inference error) is reported as a string starting
            with "❌" instead of raising, so the caller can display it as-is.
        """
        image, error = self._load_image(image_input)
        if error is not None:
            return error

        try:
            return self._classify(image)
        except Exception as e:  # UI boundary: report the failure, don't crash
            return f"❌ Inference Failed: {str(e)}"

    def _load_image(self, image_input):
        """Convert *image_input* to an RGB PIL image.

        Returns ``(image, None)`` on success and ``(None, message)`` on failure.
        """
        if image_input is None:
            return None, self._NO_IMAGE_MESSAGE

        if isinstance(image_input, str):
            # Only URL-looking strings are supported. Any other string used to
            # fall through to Image.fromarray and crash with an uncaught error.
            if not image_input.startswith("http"):
                return None, self._NO_IMAGE_MESSAGE
            return self._fetch_url(image_input)

        try:
            return Image.fromarray(image_input).convert('RGB'), None
        except Exception as e:  # malformed upload: report it like other errors
            return None, f"❌ Invalid image data: {str(e)}"

    def _fetch_url(self, url):
        """Download *url* and decode it as an RGB PIL image.

        Returns ``(image, None)`` on success and ``(None, message)`` on failure.
        """
        # A browser-like User-Agent is sent with the request.
        headers = {'User-Agent': 'Mozilla/5.0'}
        try:
            # No stream=True: the whole body is needed anyway, and a
            # non-streamed response never leaves an unread body holding the
            # connection open on the non-200 path.
            response = requests.get(
                url, headers=headers, timeout=self.REQUEST_TIMEOUT
            )
            if response.status_code != 200:
                return None, (
                    f"❌ Failed to fetch image. Status: {response.status_code}"
                )
            return Image.open(BytesIO(response.content)).convert('RGB'), None
        except Exception as e:  # network or decode failure
            return None, f"❌ Error loading URL: {str(e)}"

    def _classify(self, image) -> str:
        """Run the model on *image* and format the ranked predictions."""
        batch = self.preprocess(image).unsqueeze(0)

        with torch.no_grad():
            probabilities = self.model(batch).squeeze(0).softmax(0)

        categories = self.weights.meta["categories"]

        # A single topk call supplies both the headline prediction and the
        # ranked list, so the headline always matches candidate #1.
        k = min(self.TOP_K, probabilities.numel())
        scores, indices = torch.topk(probabilities, k)
        scores = scores.tolist()
        indices = indices.tolist()

        ranked = [
            f"{rank}. **{categories[idx]}** ({prob:.1%})"
            for rank, (prob, idx) in enumerate(zip(scores, indices), start=1)
        ]

        report_lines = [
            "📸 **Image Analysis Result**",
            "",
            f"🎯 **Top Prediction**: {categories[indices[0]]}",
            f"📊 **Confidence**: {scores[0]:.1%}",
            "",
            f"🏆 **Top {len(ranked)} Candidates**:",
            *ranked,
            "",
            "*Analysis performed by ResNet50*",
        ]
        return "\n".join(report_lines)