"""Flask API serving a MobileNetV2 waste-image classifier.

At import time the checkpoint is downloaded from the Hugging Face Hub and
loaded onto the available device.  Routes:

* ``GET  /``        - service metadata
* ``POST /predict`` - classify an image sent as base64 JSON or a multipart file
* ``GET  /health``  - liveness probe
"""
import os
import io
import base64
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download

# Import model architecture
from model_architecture import create_model

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Constants
REPO_ID = "karthikeya09/smart_image_recognation"
MODEL_FILE = "best_model.pth"
# Index order must match the output order the checkpoint was trained with.
CLASSES = ['glass', 'metal', 'non-recyclable', 'organic', 'paper', 'plastic']
CLASS_INFO = {
    'glass': {'emoji': '💚', 'color': '#28a745', 'description': 'Glass bottles, jars'},
    'metal': {'emoji': '🔘', 'color': '#6c757d', 'description': 'Cans, foil, batteries'},
    'non-recyclable': {'emoji': '⚫', 'color': '#343a40', 'description': 'Mixed/contaminated waste'},
    'organic': {'emoji': '🟢', 'color': '#20c997', 'description': 'Food waste, plant matter'},
    'paper': {'emoji': '📄', 'color': '#8B4513', 'description': 'Newspapers, cardboard'},
    'plastic': {'emoji': '🔵', 'color': '#007bff', 'description': 'Bottles, bags, containers'},
}

# Load Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"📥 Downloading model from {REPO_ID}...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
print("🧠 Loading model architecture...")
model = create_model(architecture='mobilenet_v2', num_classes=len(CLASSES), pretrained=False)
# NOTE(review): torch.load unpickles the checkpoint, which is fetched from a
# remote Hub repo. If it holds only tensors and plain containers, consider
# passing weights_only=True (the default from PyTorch 2.6) - TODO confirm.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()
print(f"✅ Model loaded on {device}")

# Transforms: resize to the network input size, then normalize with the
# standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(img):
    """Classify a PIL image.

    Args:
        img: A PIL ``Image``; non-RGB modes are converted to RGB first.

    Returns:
        dict with ``class`` (top label), ``confidence`` (percent, 2 dp),
        ``info`` (the ``CLASS_INFO`` entry for the label) and
        ``all_probabilities`` (label -> percent, 2 dp).
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')
    input_tensor = transform(img).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)
    confidence, predicted = torch.max(probabilities, 1)
    predicted_class = CLASSES[predicted.item()]
    confidence_score = confidence.item() * 100
    all_probs = {
        label: round(prob * 100, 2)
        for label, prob in zip(CLASSES, probabilities[0].tolist())
    }
    return {
        'class': predicted_class,
        'confidence': round(confidence_score, 2),
        'info': CLASS_INFO[predicted_class],
        'all_probabilities': all_probs,
    }


class _BadImageRequest(ValueError):
    """The client sent a missing or undecodable image (maps to HTTP 400)."""


def _open_image(stream):
    """Open and fully decode an image from a binary stream.

    Raises:
        _BadImageRequest: if PIL cannot identify or decode the data.
    """
    try:
        img = Image.open(stream)
        # Image.open is lazy; force decoding now so corrupt/truncated data is
        # reported as a client error instead of failing later in predict().
        img.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL's UnidentifiedImageError is a subclass of OSError.
        raise _BadImageRequest('Could not decode image') from exc
    return img


def _decode_base64_image(image_data):
    """Decode a base64 string, optionally data-URL prefixed, into an image.

    Raises:
        _BadImageRequest: on invalid base64 or undecodable image bytes.
    """
    # Strip a "data:<mime>;base64," prefix; base64 text never contains commas.
    if ',' in image_data:
        image_data = image_data.split(',')[1]
    try:
        image_bytes = base64.b64decode(image_data)
    except ValueError as exc:  # binascii.Error subclasses ValueError
        raise _BadImageRequest('Invalid base64 image data') from exc
    return _open_image(io.BytesIO(image_bytes))


def _image_from_request():
    """Extract the uploaded image from the current request.

    Accepts either a JSON object with ``image_base64`` (or ``image``) holding
    base64 data, or a multipart form upload under the ``image`` field.

    Raises:
        _BadImageRequest: if no usable image was supplied.
    """
    if request.is_json:
        # silent=True: malformed JSON yields None instead of raising
        # werkzeug's BadRequest, so it is reported as a 400 below.
        data = request.get_json(silent=True)
        image_data = None
        if isinstance(data, dict):
            image_data = data.get('image_base64') or data.get('image')
        if not image_data:
            raise _BadImageRequest('No image provided')
        if not isinstance(image_data, str):
            raise _BadImageRequest('Invalid base64 image data')
        return _decode_base64_image(image_data)
    if 'image' in request.files:
        return _open_image(request.files['image'].stream)
    raise _BadImageRequest('No image provided')


@app.route('/')
def home():
    """Report service status, model repo and supported classes."""
    return jsonify({
        'status': 'running',
        'model': REPO_ID,
        'classes': CLASSES,
        'message': 'Smart Waste Classifier API',
    })


@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """Handle prediction requests.

    Returns ``{'success': True, 'prediction': ...}`` on success.
    Client input errors return HTTP 400 and unexpected failures return
    HTTP 500, both as ``{'success': False, 'error': ...}``.
    """
    try:
        img = _image_from_request()
        result = predict(img)
    except _BadImageRequest as e:
        return jsonify({'success': False, 'error': str(e)}), 400
    except Exception as e:
        app.logger.exception('Prediction request failed')
        return jsonify({'success': False, 'error': str(e)}), 500
    return jsonify({'success': True, 'prediction': result})


@app.route('/health')
def health():
    """Liveness probe."""
    return jsonify({'status': 'healthy', 'model_loaded': model is not None})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)