File size: 4,188 Bytes
cd4dafb
51087b5
 
cd4dafb
51087b5
 
cd4dafb
 
 
 
51087b5
cd4dafb
 
51087b5
 
 
 
cd4dafb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51087b5
cd4dafb
 
 
 
 
 
 
 
51087b5
cd4dafb
51087b5
 
cd4dafb
51087b5
 
cd4dafb
51087b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd4dafb
51087b5
 
 
 
 
 
 
 
 
 
cd4dafb
51087b5
 
 
 
 
 
cd4dafb
51087b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import io
import base64
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download

# Import model architecture
from model_architecture import create_model

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes so browser front-ends on other origins can call the API

# Constants
# Hugging Face Hub repository and checkpoint filename for the trained classifier.
REPO_ID = "karthikeya09/smart_image_recognation"
MODEL_FILE = "best_model.pth"
# Label order must match the class indices the checkpoint was trained with.
CLASSES = ['glass', 'metal', 'non-recyclable', 'organic', 'paper', 'plastic']
# Per-class display metadata returned to clients alongside each prediction.
CLASS_INFO = {
    'glass': {'emoji': 'πŸ’š', 'color': '#28a745', 'description': 'Glass bottles, jars'},
    'metal': {'emoji': 'πŸ”˜', 'color': '#6c757d', 'description': 'Cans, foil, batteries'},
    'non-recyclable': {'emoji': '⚫', 'color': '#343a40', 'description': 'Mixed/contaminated waste'},
    'organic': {'emoji': '🟒', 'color': '#20c997', 'description': 'Food waste, plant matter'},
    'paper': {'emoji': 'πŸ“„', 'color': '#8B4513', 'description': 'Newspapers, cardboard'},
    'plastic': {'emoji': 'πŸ”΅', 'color': '#007bff', 'description': 'Bottles, bags, containers'}
}

# Load Model
# NOTE(review): this runs at import time — it performs a network download
# (hf_hub_download) and fails the whole app if the Hub is unreachable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"πŸ“₯ Downloading model from {REPO_ID}...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)

print("🧠 Loading model architecture...")
# pretrained=False: weights come entirely from the downloaded checkpoint below.
model = create_model(architecture='mobilenet_v2', num_classes=len(CLASSES), pretrained=False)
# map_location=device lets a CUDA-trained checkpoint load on a CPU-only host.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()  # inference mode: disables dropout/batch-norm training behavior
print(f"βœ… Model loaded on {device}")

# Transforms
# Standard ImageNet preprocessing (224x224, ImageNet mean/std) — matches
# what MobileNetV2 expects; presumably the same pipeline used in training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def predict(img):
    """Classify a PIL image into one of CLASSES.

    Converts the image to RGB if needed, applies the module-level
    preprocessing `transform`, and runs a single no-grad forward pass
    through the loaded model.

    Returns a dict with the winning class name, its confidence as a
    percentage, the class's display info, and the full per-class
    probability breakdown (percentages rounded to 2 decimals).
    """
    rgb = img if img.mode == 'RGB' else img.convert('RGB')

    # Add a batch dimension and move to the model's device.
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)
        top_prob, top_idx = torch.max(probs, 1)

    label = CLASSES[top_idx.item()]

    breakdown = {}
    for i, name in enumerate(CLASSES):
        breakdown[name] = round(probs[0][i].item() * 100, 2)

    return {
        'class': label,
        'confidence': round(top_prob.item() * 100, 2),
        'info': CLASS_INFO[label],
        'all_probabilities': breakdown
    }

@app.route('/')
def home():
    """Root endpoint: report API status, model repo, and supported classes."""
    payload = {
        'status': 'running',
        'model': REPO_ID,
        'classes': CLASSES,
        'message': 'Smart Waste Classifier API'
    }
    return jsonify(payload)

@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """Classify an uploaded image.

    Accepts either:
      * a JSON body with 'image_base64' or 'image' holding a base64 string
        (optionally a data URL, e.g. "data:image/png;base64,...."), or
      * multipart form data with the file under the 'image' field.

    Returns JSON {'success': True, 'prediction': {...}} on success,
    400 for missing or undecodable client input, 500 for unexpected
    server-side failures.
    """
    try:
        # Handle JSON with base64 image
        if request.is_json:
            data = request.get_json()
            image_data = data.get('image_base64') or data.get('image')

            # Reject missing values and non-string payloads up front.
            if not image_data or not isinstance(image_data, str):
                return jsonify({'success': False, 'error': 'No image provided'}), 400

            # Strip a data-URL prefix ("data:image/...;base64,") if present;
            # maxsplit=1 keeps any commas inside the payload intact.
            if ',' in image_data:
                image_data = image_data.split(',', 1)[1]

            try:
                image_bytes = base64.b64decode(image_data)
                img = Image.open(io.BytesIO(image_bytes))
            except (ValueError, OSError) as e:
                # binascii.Error (bad base64) subclasses ValueError;
                # PIL.UnidentifiedImageError (bad image bytes) subclasses
                # OSError — both are client errors, so answer 400, not 500.
                return jsonify({'success': False, 'error': f'Invalid image data: {e}'}), 400

        # Handle form data with file upload
        elif 'image' in request.files:
            file = request.files['image']
            try:
                img = Image.open(file.stream)
            except OSError as e:
                return jsonify({'success': False, 'error': f'Invalid image file: {e}'}), 400

        else:
            return jsonify({'success': False, 'error': 'No image provided'}), 400

        result = predict(img)
        return jsonify({'success': True, 'prediction': result})

    except Exception as e:
        # Unexpected failures (model/runtime errors) surface as 500.
        return jsonify({'success': False, 'error': str(e)}), 500

@app.route('/health')
def health():
    """Liveness probe: report server health and whether the model is loaded."""
    loaded = model is not None
    return jsonify({'status': 'healthy', 'model_loaded': loaded})

if __name__ == '__main__':
    # Bind to all interfaces; port 7860 is the Hugging Face Spaces convention.
    # NOTE(review): this is Flask's development server — use a WSGI server
    # (e.g. gunicorn) in production.
    app.run(host='0.0.0.0', port=7860)