import torch import torch.nn as nn from torchvision import transforms, models from PIL import Image import io import os from flask import Flask, request, jsonify, render_template_string from flask_cors import CORS import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = Flask(__name__) CORS(app) app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 MODEL_PATH = 'Model-79.85.pth' DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") IMAGE_SIZE = 224 PORT = int(os.environ.get("PORT", 7860)) # HF Spaces uses 7860 def load_model(): if not os.path.exists(MODEL_PATH): logger.error(f"ERROR: Model file '{MODEL_PATH}' not found!") return None, {} try: checkpoint = torch.load(MODEL_PATH, map_location=DEVICE) model = models.mobilenet_v3_large(weights=None) in_features = 960 model.classifier = nn.Sequential( nn.Linear(in_features, 1280), nn.Hardswish(inplace=True), nn.Dropout(0.3, inplace=True), nn.Linear(1280, 640), nn.Hardswish(inplace=True), nn.Dropout(0.2, inplace=True), nn.Linear(640, 120) ) model.load_state_dict(checkpoint['model_state_dict']) model.to(DEVICE) model.eval() idx_to_class = checkpoint.get('idx_to_class', {}) class_names = {int(k): v for k, v in idx_to_class.items()} val_acc = checkpoint.get('best_top1', checkpoint.get('best_acc', 'N/A')) if isinstance(val_acc, (int, float)): logger.info(f"Model loaded! Classes: {len(class_names)}, Best Val Acc: {val_acc:.2f}%") else: logger.info(f"Model loaded! Classes: {len(class_names)}") return model, class_names except Exception as e: logger.error(f"Error loading model: {e}") import traceback logger.error(traceback.format_exc()) return None, {} model, IDX_TO_CLASS = load_model() CLASSES = [IDX_TO_CLASS[i] for i in range(len(IDX_TO_CLASS))] if IDX_TO_CLASS else [] transform = transforms.Compose([ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def predict_with_tta(image, model, device): model.eval() transform_flip = transforms.Compose([ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) orig_tensor = transform(image).unsqueeze(0).to(device) flip_tensor = transform_flip(image).unsqueeze(0).to(device) with torch.no_grad(): outputs_orig = torch.nn.functional.softmax(model(orig_tensor), dim=1) outputs_flip = torch.nn.functional.softmax(model(flip_tensor), dim=1) outputs = (outputs_orig + outputs_flip) / 2 confidence, predicted = torch.max(outputs, 1) return predicted.item(), float(confidence.item()) HTML_TEMPLATE = '''