Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import base64 | |
| import torch | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from PIL import Image | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download | |
| # Import model architecture | |
| from model_architecture import create_model | |
# Initialize the Flask application and turn on CORS (flask_cors with its
# default options) for every route it serves.
app = Flask(import_name=__name__)
CORS(app)
# Constants
REPO_ID = "karthikeya09/smart_image_recognation"  # Hugging Face Hub repo with the checkpoint
MODEL_FILE = "best_model.pth"  # checkpoint filename inside REPO_ID

# Label names; position i corresponds to the model's output index i.
CLASSES = [
    'glass',
    'metal',
    'non-recyclable',
    'organic',
    'paper',
    'plastic',
]

# Display metadata per label, attached to each prediction returned to clients.
# NOTE(review): the emoji literals below look mis-encoded (mojibake) in this
# copy of the file; confirm the intended characters against the original.
CLASS_INFO = {
    'glass': {
        'emoji': 'π',
        'color': '#28a745',
        'description': 'Glass bottles, jars',
    },
    'metal': {
        'emoji': 'π',
        'color': '#6c757d',
        'description': 'Cans, foil, batteries',
    },
    'non-recyclable': {
        'emoji': 'β«',
        'color': '#343a40',
        'description': 'Mixed/contaminated waste',
    },
    'organic': {
        'emoji': 'π’',
        'color': '#20c997',
        'description': 'Food waste, plant matter',
    },
    'paper': {
        'emoji': 'π',
        'color': '#8B4513',
        'description': 'Newspapers, cardboard',
    },
    'plastic': {
        'emoji': 'π΅',
        'color': '#007bff',
        'description': 'Bottles, bags, containers',
    },
}
# Load Model
# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print(f"π₯ Downloading model from {REPO_ID}...")
model_path = hf_hub_download(filename=MODEL_FILE, repo_id=REPO_ID)

print("π§ Loading model architecture...")
model = create_model(
    architecture='mobilenet_v2',
    num_classes=len(CLASSES),
    pretrained=False,  # weights come from the checkpoint below
)

# NOTE(review): torch.load unpickles the file. Depending on the installed torch
# version, `weights_only` may default to False, which lets a tampered checkpoint
# execute code. The file comes from the fixed REPO_ID; consider passing
# weights_only=True if the checkpoint only holds tensors and plain containers.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model = model.to(device)
model.eval()  # evaluation mode for inference
print(f"β Model loaded on {device}")
# Transforms: resize every image to 224x224, convert it to a tensor, then
# normalise each channel with the commonly used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(img):
    """Classify one image with the loaded model.

    Args:
        img: an image object exposing ``.mode`` / ``.convert`` (a PIL image).
            Non-RGB images are converted to RGB first.

    Returns:
        A dict with:
          * ``class``: the label with the highest probability;
          * ``confidence``: that probability as a percentage, rounded to 2 dp;
          * ``info``: the ``CLASS_INFO`` entry for the predicted label;
          * ``all_probabilities``: label -> percentage (2 dp) for every class.
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')

    # Add a batch dimension of 1 and move the input to the model's device.
    input_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    confidence, predicted = torch.max(probabilities, 1)
    predicted_class = CLASSES[predicted.item()]

    # One device-to-host copy for the whole row instead of an .item() per class.
    row = probabilities[0].tolist()
    all_probs = {name: round(p * 100, 2) for name, p in zip(CLASSES, row)}

    return {
        'class': predicted_class,
        'confidence': round(confidence.item() * 100, 2),
        'info': CLASS_INFO[predicted_class],
        'all_probabilities': all_probs,
    }
# Register as a Flask view; without a route decorator the function is never
# reachable over HTTP.
@app.route('/')
def home():
    """Report that the API is running, with the model repo and class labels."""
    return jsonify({
        'status': 'running',
        'model': REPO_ID,
        'classes': CLASSES,
        'message': 'Smart Waste Classifier API',
    })
class _ImageRequestError(Exception):
    """Client-side problem with the submitted image; message is safe to return."""


def _read_request_image():
    """Extract and fully decode the image carried by the current request.

    Accepted request shapes:
      * JSON object with an ``image_base64`` (or ``image``) field containing
        base64 data, optionally prefixed by a data-URL header such as
        ``data:image/png;base64,``;
      * multipart form upload with a file field named ``image``.

    Returns:
        The decoded PIL image.

    Raises:
        _ImageRequestError: if no image is supplied or it cannot be decoded.
    """
    if request.is_json:
        # silent=True: malformed JSON yields None instead of raising, so it is
        # reported as a client error below rather than a server error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            raise _ImageRequestError('Request body must be a JSON object')
        image_data = data.get('image_base64') or data.get('image')
        if not image_data:
            raise _ImageRequestError('No image provided')
        if not isinstance(image_data, str):
            raise _ImageRequestError('Image must be a base64-encoded string')
        # Remove a data-URL prefix if present (everything up to the first comma).
        _, sep, payload = image_data.partition(',')
        if sep:
            image_data = payload
        try:
            image_bytes = base64.b64decode(image_data)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise _ImageRequestError('Invalid base64 image data') from exc
        source = io.BytesIO(image_bytes)
    elif 'image' in request.files:
        source = request.files['image'].stream
    else:
        raise _ImageRequestError('No image provided')

    try:
        img = Image.open(source)
        # Image.open is lazy; decode now so corrupt or truncated uploads are
        # rejected here as client errors instead of failing inside predict().
        img.load()
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise _ImageRequestError('Could not decode image') from exc
    return img


@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """Handle prediction requests.

    Responses:
      * 200 ``{'success': True, 'prediction': {...}}`` with the result of
        ``predict``;
      * 400 ``{'success': False, 'error': <reason>}`` when the request has no
        usable image (missing, bad JSON, bad base64, undecodable image);
      * 500 ``{'success': False, 'error': 'Prediction failed'}`` on any other
        error. The traceback is logged server-side, not sent to the client.
    """
    try:
        img = _read_request_image()
        result = predict(img)
    except _ImageRequestError as exc:
        return jsonify({'success': False, 'error': str(exc)}), 400
    except Exception:
        # Top-level request boundary: log everything, expose nothing internal.
        app.logger.exception('Prediction failed')
        return jsonify({'success': False, 'error': 'Prediction failed'}), 500
    return jsonify({'success': True, 'prediction': result})
# Register as a Flask view; without a route decorator the function is never
# reachable over HTTP.
@app.route('/health')
def health():
    """Health check.

    ``model`` is assigned while the module is imported, and a failed download
    or load aborts the import. So whenever this view runs, ``model_loaded`` is
    True.
    """
    return jsonify({'status': 'healthy', 'model_loaded': model is not None})
if __name__ == '__main__':
    # Start the built-in server on every network interface, port 7860.
    app.run(port=7860, host='0.0.0.0')