Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import datetime | |
| import torch.nn.functional as F | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from torchvision import transforms | |
| from PIL import Image | |
| from transformers import ViTForImageClassification | |
| from huggingface_hub import hf_hub_download | |
| from werkzeug.utils import secure_filename | |
| from pymongo import MongoClient | |
| import warnings | |
# =====================
# Warning suppression
# =====================
# NOTE: this installs a process-wide "ignore" filter for every warning
# category, not just Hugging Face ones. Messages emitted through the
# `logging` module are not affected by warning filters.
warnings.simplefilter("ignore")
# =====================
# Flask app
# =====================
app = Flask(__name__)
app.config.update(UPLOAD_FOLDER="uploads")
# Relative path: the folder is created under the current working directory.
os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
# Enable CORS for the app using flask_cors's default settings.
CORS(app)
# =====================
# MongoDB
# =====================
# Connection string is read from the environment (e.g. a Space secret).
MONGO_URI = os.getenv("MONGO_URI")
if not MONGO_URI:
    # MongoClient(None) quietly falls back to mongodb://localhost:27017, so a
    # missing secret would otherwise only show up later as failing inserts.
    app.logger.warning(
        "MONGO_URI is not set; MongoClient will fall back to localhost:27017"
    )
client = MongoClient(MONGO_URI)
db = client["skin-disease-db"]
# Collection that receives one document per prediction.
reports = db["reports"]
# =====================
# Inference device and class labels
# =====================
# All inference runs on the CPU.
device = torch.device("cpu")

# The model's argmax output is used as an index into this list, so the
# ORDER MUST MATCH the label order used during training.
labels = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]
# Size of the classification head; derived so it can never drift from labels.
NUM_CLASSES = len(labels)
# =====================
# Load trained model
# =====================
weights_path = hf_hub_download(
    repo_id="pragun3669/dermify-vit",
    filename="best_vit1_model.pth",
)
# Build the architecture from the public ViT-Large checkpoint with the head
# resized to NUM_CLASSES (ignore_mismatched_sizes lets the 1000-way head be
# re-initialised). The fine-tuned weights loaded below are meant to replace
# both the backbone and that fresh head.
model = ViTForImageClassification.from_pretrained(
    "google/vit-large-patch16-224",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)
# Load the full fine-tuned state (backbone + classifier).
# SECURITY NOTE(review): torch.load unpickles the downloaded file; on torch
# versions where weights_only defaults to False this can run arbitrary code
# embedded in the checkpoint. Consider weights_only=True if the file is a
# plain state_dict.
state_dict = torch.load(weights_path, map_location=device)
# strict=False silently skips keys that do not line up (e.g. a "module."
# prefix), which would leave the randomly initialised head in place, so
# report any mismatch instead of hiding it.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    app.logger.warning(
        "Checkpoint/model key mismatch: %d missing (e.g. %s), %d unexpected (e.g. %s)",
        len(load_result.missing_keys),
        load_result.missing_keys[:5],
        len(load_result.unexpected_keys),
        load_result.unexpected_keys[:5],
    )
# load_state_dict copied the tensors into the model; drop the checkpoint so a
# second full copy of the weights is not kept alive for the process lifetime.
del state_dict
model.to(device)
model.eval()
# =====================
# Image preprocessing (MATCH TRAINING)
# =====================
# Resize to 224x224, convert to a float tensor in [0, 1], then normalise each
# RGB channel with mean 0.5 / std 0.5, i.e. map values into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# =====================
# Prediction Route
# =====================
# NOTE(review): no `@app.route(...)` decorator is visible on this function in
# this file, so Flask will not dispatch requests to it as written. Confirm how
# it is meant to be registered (e.g. `@app.route("/predict", methods=["POST"])`).
def predict():
    """Classify an uploaded skin image and store a report of the result.

    Expects a multipart form upload with the image in the ``file`` field.
    The image is decoded directly from the request stream, so nothing is
    written to ``UPLOAD_FOLDER``. As a result, concurrent uploads that share a
    filename cannot overwrite each other, and no files accumulate on disk.

    Returns:
        On success, JSON ``{"prediction": <label>, "confidence": <percent>}``.
        ``confidence`` is the softmax probability of the top class as a
        percentage rounded to 2 decimals. Returns ``({"error": ...}, 400)``
        when the file is missing, has an empty name, or cannot be decoded as
        an image.
    """
    from pymongo.errors import PyMongoError

    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    file = request.files["file"]
    if not file.filename:
        return jsonify({"error": "Empty file"}), 400

    # Decode in memory instead of saving under secure_filename(). That name
    # can be empty (e.g. for "../"), and identical names from concurrent
    # requests would clobber each other's file on disk.
    try:
        with Image.open(file.stream) as uploaded:
            image = uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError and truncated-image errors subclass OSError.
        return jsonify({"error": "Invalid image file"}), 400

    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor).logits
    probs = F.softmax(logits, dim=1)
    idx = probs.argmax(dim=1).item()
    prediction = labels[idx]
    confidence = round(probs[0][idx].item() * 100, 2)

    try:
        reports.insert_one({
            "prediction": prediction,
            "confidence": confidence,
            "createdAt": datetime.datetime.now(datetime.timezone.utc),
        })
    except PyMongoError:
        # Persisting the report is secondary; still return the prediction.
        app.logger.exception("Failed to store prediction report")

    return jsonify({"prediction": prediction, "confidence": confidence})
# =====================
# Entry point (HF Spaces)
# =====================
if __name__ == "__main__":
    # Listen on all interfaces; 7860 is the Hugging Face Spaces default app port.
    app.run(port=7860, host="0.0.0.0")