Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import io
import os
from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
import logging

# Module-level logger for startup and per-request diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)  # allow cross-origin requests (UI may be embedded in another page)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # reject uploads > 16 MiB

# Checkpoint produced by training; filename encodes best validation accuracy.
MODEL_PATH = 'Model-79.85.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 224  # input resolution expected by MobileNetV3
PORT = int(os.environ.get("PORT", 7860))  # HF Spaces uses 7860
def load_model():
    """Load the MobileNetV3-Large checkpoint from MODEL_PATH.

    Returns:
        (model, class_names): the eval-mode model moved to DEVICE, and a
        dict mapping class index (int) -> breed name. On any failure the
        function logs the problem and returns (None, {}) so the web app
        can still start and report the error via /health.
    """
    if not os.path.exists(MODEL_PATH):
        logger.error("ERROR: Model file '%s' not found!", MODEL_PATH)
        return None, {}
    try:
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects -- only load checkpoints from a trusted source.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

        # Rebuild the architecture used at training time: stock
        # MobileNetV3-Large backbone with a custom 3-layer classifier head.
        model = models.mobilenet_v3_large(weights=None)
        in_features = 960  # feature width emitted by mobilenet_v3_large
        model.classifier = nn.Sequential(
            nn.Linear(in_features, 1280),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.3, inplace=True),
            nn.Linear(1280, 640),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(640, 120),  # 120 output classes (dog breeds)
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(DEVICE)
        model.eval()

        # Checkpoints may serialize index keys as strings; normalize to int.
        idx_to_class = checkpoint.get('idx_to_class', {})
        class_names = {int(k): v for k, v in idx_to_class.items()}

        # Accuracy key differs between checkpoint versions; 'N/A' if absent.
        val_acc = checkpoint.get('best_top1', checkpoint.get('best_acc', 'N/A'))
        if isinstance(val_acc, (int, float)):
            logger.info("Model loaded! Classes: %d, Best Val Acc: %.2f%%",
                        len(class_names), val_acc)
        else:
            logger.info("Model loaded! Classes: %d", len(class_names))
        return model, class_names
    except Exception:
        # logger.exception appends the full traceback automatically,
        # replacing the previous manual `import traceback` / format_exc().
        logger.exception("Error loading model")
        return None, {}
# Load the model once at import time; IDX_TO_CLASS maps index -> breed name.
model, IDX_TO_CLASS = load_model()

# Ordered class-name list. Use .get with a fallback so a checkpoint with
# sparse/non-contiguous indices degrades to "Unknown" entries instead of
# crashing the whole app at import time with a KeyError.
CLASSES = [IDX_TO_CLASS.get(i, "Unknown") for i in range(len(IDX_TO_CLASS))] if IDX_TO_CLASS else []

# Standard ImageNet preprocessing, matching the backbone's training stats.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
def predict_with_tta(image, model, device):
    """Classify a PIL image with test-time augmentation.

    Runs the model on the original image and on a horizontally mirrored
    copy, averages the two softmax distributions, and returns
    (predicted_index, confidence) for the best class.
    """
    model.eval()
    # Same preprocessing as the module-level `transform`, plus a
    # guaranteed horizontal flip (p=1.0 makes it deterministic).
    mirrored_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    batch_plain = transform(image).unsqueeze(0).to(device)
    batch_mirror = mirrored_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        avg_probs = (
            torch.nn.functional.softmax(model(batch_plain), dim=1)
            + torch.nn.functional.softmax(model(batch_mirror), dim=1)
        ) / 2
    best_prob, best_idx = torch.max(avg_probs, 1)
    return best_idx.item(), float(best_prob.item())
# Single-page, mobile-friendly UI (Jinja template rendered by index()).
# Contains inline CSS and the JS that POSTs the chosen image to /predict.
HTML_TEMPLATE = '''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<meta name="mobile-web-app-capable" content="yes">
<title>Dog Breed Classifier</title>
<style>
* { box-sizing: border-box; -webkit-tap-highlight-color: transparent; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
margin: 0;
padding: 15px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
touch-action: manipulation;
}
.container {
max-width: 600px;
margin: 0 auto;
background: white;
border-radius: 20px;
padding: 25px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
}
h1 {
text-align: center;
color: #333;
margin: 0 0 10px 0;
font-size: 28px;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 25px;
font-size: 16px;
}
.model-badge {
text-align: center;
background: #e3f2fd;
padding: 10px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
color: #1976d2;
font-weight: 600;
}
.button-grid {
display: grid;
gap: 15px;
margin-bottom: 20px;
}
.btn {
background: #4CAF50;
color: white;
border: none;
padding: 18px;
border-radius: 12px;
font-size: 18px;
font-weight: 600;
cursor: pointer;
width: 100%;
transition: transform 0.2s, box-shadow 0.2s;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
display: block;
text-align: center;
text-decoration: none;
-webkit-touch-callout: none;
-webkit-user-select: none;
user-select: none;
}
.btn:active {
transform: translateY(2px) scale(0.98);
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.btn-camera { background: #2196F3; }
.btn-gallery { background: #FF9800; }
input[type="file"] {
position: absolute;
opacity: 0;
pointer-events: none;
width: 0;
height: 0;
}
.preview-container {
margin-top: 20px;
text-align: center;
display: none;
}
#preview {
max-width: 100%;
max-height: 300px;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.15);
}
.loading {
display: none;
text-align: center;
margin-top: 20px;
}
.spinner {
border: 5px solid #f3f3f3;
border-top: 5px solid #3498db;
border-radius: 50%;
width: 50px;
height: 50px;
animation: spin 1s linear infinite;
margin: 0 auto 15px auto;
}
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
.result {
margin-top: 20px;
padding: 25px;
background: #f8f9fa;
border-radius: 15px;
display: none;
text-align: center;
border: 2px solid #e9ecef;
}
.breed-name {
font-size: 24px;
font-weight: bold;
margin: 10px 0;
color: #333;
line-height: 1.3;
text-transform: capitalize;
}
.confidence {
font-size: 20px;
color: #666;
font-weight: 500;
}
.top-k {
margin-top: 15px;
text-align: left;
background: white;
padding: 15px;
border-radius: 8px;
border: 1px solid #e0e0e0;
}
.top-k-title {
font-size: 14px;
color: #999;
margin-bottom: 8px;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.breed-item {
display: flex;
justify-content: space-between;
padding: 5px 0;
border-bottom: 1px solid #f0f0f0;
font-size: 14px;
}
.breed-item:last-child { border-bottom: none; }
.breed-prob { font-weight: 600; color: #2196F3; }
.error {
margin-top: 15px;
padding: 15px;
background: #fee;
color: #c33;
border-radius: 8px;
border-left: 4px solid #c33;
display: none;
}
.status-badge {
text-align: center;
padding: 8px;
border-radius: 6px;
margin-bottom: 15px;
font-size: 13px;
font-weight: 600;
}
.status-ok { background: #e8f5e9; color: #2e7d32; }
.status-error { background: #ffebee; color: #c62828; }
</style>
</head>
<body>
<div class="container">
<div class="status-badge {{ 'status-ok' if model_loaded else 'status-error' }}">
{{ '✓ Model Ready' if model_loaded else '✗ Model Not Loaded' }}
</div>
<h1>Dog Breed Classifier</h1>
<div class="button-grid">
<label for="cameraInput" class="btn btn-camera">
📷 Take Photo
</label>
<label for="galleryInput" class="btn btn-gallery">
🖼️ Upload from Gallery
</label>
</div>
<input type="file" id="cameraInput" accept="image/*" capture="environment" onchange="handleFile(this)">
<input type="file" id="galleryInput" accept="image/*" onchange="handleFile(this)">
<div class="preview-container" id="previewContainer">
<img id="preview" alt="Selected dog">
</div>
<div class="loading" id="loading">
<div class="spinner"></div>
<div>Identifying breed...</div>
</div>
<div class="result" id="result">
<div class="breed-name" id="breedName"></div>
<div class="confidence" id="confidence"></div>
<div class="top-k" id="topK" style="display: none;">
<div class="top-k-title">Also possible:</div>
<div id="topKList"></div>
</div>
</div>
<div class="error" id="error"></div>
</div>
<script>
function handleFile(input) {
const file = input.files[0];
if (!file) return;
document.getElementById('error').style.display = 'none';
document.getElementById('result').style.display = 'none';
const reader = new FileReader();
reader.onload = function(e) {
const preview = document.getElementById('preview');
preview.src = e.target.result;
document.getElementById('previewContainer').style.display = 'block';
uploadImage(file);
};
reader.readAsDataURL(file);
input.value = '';
}
function uploadImage(file) {
document.getElementById('loading').style.display = 'block';
const formData = new FormData();
formData.append('image', file);
fetch('/predict', {
method: 'POST',
body: formData,
headers: {
'Accept': 'application/json',
}
})
.then(response => {
if (!response.ok) throw new Error('Server error: ' + response.status);
return response.json();
})
.then(data => {
document.getElementById('loading').style.display = 'none';
if (data.error) {
showError(data.error);
} else {
showResult(data);
}
})
.catch(error => {
document.getElementById('loading').style.display = 'none';
console.error("Upload error:", error);
showError("Failed to analyze: " + error.message);
});
}
function showResult(data) {
document.getElementById('breedName').textContent = data.class.replace(/_/g, ' ');
document.getElementById('confidence').textContent =
`Confidence: ${(data.confidence * 100).toFixed(1)}%`;
document.getElementById('result').style.display = 'block';
if (data.top3 && data.top3.length > 1) {
const list = document.getElementById('topKList');
list.innerHTML = '';
data.top3.slice(1).forEach(item => {
const div = document.createElement('div');
div.className = 'breed-item';
div.innerHTML = `<span>${item.class.replace(/_/g, ' ')}</span><span class="breed-prob">${(item.confidence * 100).toFixed(1)}%</span>`;
list.appendChild(div);
});
document.getElementById('topK').style.display = 'block';
}
setTimeout(() => {
document.getElementById('result').scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}, 100);
}
function showError(msg) {
const errorDiv = document.getElementById('error');
errorDiv.textContent = msg;
errorDiv.style.display = 'block';
errorDiv.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
</script>
</body>
</html>
'''
# NOTE(review): the route decorator was missing in the extracted source;
# the template's fetch('/predict') implies these Flask routes existed.
@app.route('/')
def index():
    """Serve the single-page upload UI, flagging whether the model loaded."""
    return render_template_string(HTML_TEMPLATE, model_loaded=model is not None)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded dog photo.

    Expects multipart/form-data with an 'image' file field. Responds with
    JSON {'class', 'confidence', 'top3', 'success'} on success, or
    {'error': msg} with HTTP 400/500 on failure.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400
    try:
        image_bytes = file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        # Headline prediction uses TTA (original + mirrored views averaged).
        pred_idx, confidence = predict_with_tta(image, model, DEVICE)
        breed_name = IDX_TO_CLASS.get(pred_idx, "Unknown")

        # NOTE(review): top-3 comes from a single non-TTA pass, so top3[0]
        # can occasionally disagree with the TTA headline class above.
        with torch.no_grad():
            tensor = transform(image).unsqueeze(0).to(DEVICE)
            outputs = torch.nn.functional.softmax(model(tensor), dim=1)
            probs, indices = torch.topk(outputs, k=3, dim=1)
        top3 = [
            {
                'class': IDX_TO_CLASS.get(indices[0][i].item(), "Unknown"),
                'confidence': probs[0][i].item(),
            }
            for i in range(3)
        ]
        return jsonify({
            'class': breed_name,
            'confidence': confidence,
            'top3': top3,
            'success': True
        })
    except Exception as e:
        # logger.exception logs the full traceback automatically, replacing
        # the previous manual `import traceback` / format_exc() pattern.
        logger.exception("Prediction error")
        return jsonify({'error': str(e)}), 500
@app.route('/health')
def health():
    """Liveness/readiness probe: report model status and runtime config."""
    return jsonify({
        'status': 'healthy',
        'model_loaded': model is not None,
        'model_type': 'MobileNetV3-Large',
        'num_classes': len(CLASSES),
        'input_size': IMAGE_SIZE,
        'device': str(DEVICE)
    })
if __name__ == '__main__':
    # Bind to 0.0.0.0 so the server is reachable from outside the container;
    # PORT defaults to 7860 (Hugging Face Spaces convention).
    app.run(host='0.0.0.0', port=PORT, debug=False)