"""Flask app serving a MobileNetV3-Large dog-breed classifier.

Endpoints:
    GET  /         -- renders HTML_TEMPLATE, reporting whether the model loaded.
    POST /predict  -- multipart upload under form field ``image``; returns the
                      predicted breed, its confidence and the top-k classes.
    GET  /health   -- JSON status / configuration summary.
"""

import io
import logging
import os

import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
from PIL import Image
from torchvision import transforms, models

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # cap request bodies at 16 MiB

MODEL_PATH = 'Model-79.85.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 224
PORT = int(os.environ.get("PORT", 7860))  # HF Spaces uses 7860
NUM_CLASSES = 120  # output width of the checkpoint's final Linear layer
TOP_K = 3  # number of ranked alternatives returned by /predict


def _build_model():
    """Build MobileNetV3-Large with the custom 3-layer classifier head.

    The layer order must match the checkpoint's state_dict keys: the Linear
    layers sit at classifier indices 0, 3 and 6.
    """
    net = models.mobilenet_v3_large(weights=None)
    # Feature width feeding the stock head (the original code hard-coded 960).
    in_features = net.classifier[0].in_features
    net.classifier = nn.Sequential(
        nn.Linear(in_features, 1280),
        nn.Hardswish(inplace=True),
        nn.Dropout(0.3, inplace=True),
        nn.Linear(1280, 640),
        nn.Hardswish(inplace=True),
        nn.Dropout(0.2, inplace=True),
        nn.Linear(640, NUM_CLASSES),
    )
    return net


def load_model():
    """Load the fine-tuned classifier and its index-to-breed mapping.

    Returns:
        ``(model, class_names)``: the model on DEVICE in eval mode, and a dict
        mapping int class index -> breed name (empty if the checkpoint has no
        ``idx_to_class``). Returns ``(None, {})`` when MODEL_PATH is missing or
        loading fails; the error is logged rather than raised.
    """
    if not os.path.exists(MODEL_PATH):
        logger.error("ERROR: Model file '%s' not found!", MODEL_PATH)
        return None, {}

    try:
        # torch.load deserializes via pickle: MODEL_PATH must be a trusted file.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        net = _build_model()
        net.load_state_dict(checkpoint['model_state_dict'])
        net.to(DEVICE)
        net.eval()

        # Keys may have been stored as strings; normalize to int indices.
        idx_to_class = checkpoint.get('idx_to_class', {})
        class_names = {int(k): v for k, v in idx_to_class.items()}

        val_acc = checkpoint.get('best_top1', checkpoint.get('best_acc', 'N/A'))
        if isinstance(val_acc, (int, float)):
            logger.info("Model loaded! Classes: %d, Best Val Acc: %.2f%%",
                        len(class_names), val_acc)
        else:
            logger.info("Model loaded! Classes: %d", len(class_names))
        return net, class_names
    except Exception as e:
        # Best-effort startup: the app still serves / and /health without a model.
        logger.exception("Error loading model: %s", e)
        return None, {}


model, IDX_TO_CLASS = load_model()
# Ordered by class index; sorting (rather than range(len(...))) tolerates
# non-contiguous indices instead of raising KeyError at import time.
CLASSES = [IDX_TO_CLASS[i] for i in sorted(IDX_TO_CLASS)] if IDX_TO_CLASS else []

# Inference preprocessing: fixed-size resize plus ImageNet mean/std normalization.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _tta_probabilities(image, model, device):
    """Return softmax probabilities averaged over *image* and its mirror.

    Args:
        image: RGB PIL image.
        model: classifier returning logits of shape (batch, num_classes).
        device: torch device the model lives on.

    Returns:
        1-D tensor of per-class probabilities (length num_classes) on *device*.
    """
    model.eval()
    orig = transform(image).unsqueeze(0).to(device)  # (1, C, H, W)
    # ToTensor and Normalize act per pixel / per channel, so mirroring the
    # normalized tensor along width equals mirroring the resized PIL image;
    # this avoids a second resize and lets both views share one forward pass.
    batch = torch.cat([orig, torch.flip(orig, dims=[3])], dim=0)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch), dim=1)
    return probs.mean(dim=0)


def predict_with_tta(image, model, device):
    """Classify *image* with horizontal-flip test-time augmentation.

    Returns:
        ``(class_index, confidence)``, where confidence is the averaged
        probability of the winning class.
    """
    probs = _tta_probabilities(image, model, device)
    confidence, predicted = torch.max(probs, dim=0)
    return predicted.item(), float(confidence.item())


# NOTE(review): the markup in this template appears to have been stripped —
# only text nodes remain (no form, <img> or script for /predict). Restore the
# full HTML from source control; the text below is kept byte-for-byte.
HTML_TEMPLATE = ''' Dog Breed Classifier
{{ '✓ Model Ready' if model_loaded else '✗ Model Not Loaded' }}

Dog Breed Classifier

Selected dog
Identifying breed...
'''


@app.route('/')
def index():
    """Render the landing page, flagging whether the model loaded."""
    return render_template_string(HTML_TEMPLATE, model_loaded=model is not None)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image (multipart field ``image``).

    Responses:
        200: ``{'class', 'confidence', 'top3', 'success': True}``. ``top3`` holds
             up to TOP_K ``{'class', 'confidence'}`` entries, all taken from the
             same TTA-averaged probabilities as ``class``/``confidence``.
        400: missing/empty upload or undecodable image.
        500: model not loaded, or inference failed.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    file = request.files['image']
    if not file.filename:  # covers both '' and None
        return jsonify({'error': 'No file selected'}), 400

    try:
        image = Image.open(io.BytesIO(file.read())).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError / truncated data are OSError subclasses:
        # a bad upload is a client error, not a server fault.
        logger.warning("Rejected undecodable upload: %s", e)
        return jsonify({'error': 'Invalid image file'}), 400

    try:
        probs = _tta_probabilities(image, model, DEVICE)
        k = min(TOP_K, probs.numel())
        top_probs, top_indices = torch.topk(probs, k=k)
        top3 = [
            {'class': IDX_TO_CLASS.get(idx, "Unknown"), 'confidence': prob}
            for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
        ]
        return jsonify({
            'class': top3[0]['class'],
            'confidence': top3[0]['confidence'],
            'top3': top3,
            'success': True,
        })
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        return jsonify({'error': str(e)}), 500


@app.route('/health', methods=['GET'])
def health():
    """Report liveness plus basic model/runtime configuration."""
    return jsonify({
        'status': 'healthy',
        'model_loaded': model is not None,
        'model_type': 'MobileNetV3-Large',
        'num_classes': len(CLASSES),
        'input_size': IMAGE_SIZE,
        'device': str(DEVICE),
    })


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=PORT, debug=False)