# dermify-ml / app.py
# pragun3669's picture
# Update app.py
# 6528f19 verified
import datetime
import os
import warnings

import torch
import torch.nn.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS
from huggingface_hub import hf_hub_download
from PIL import Image
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from torchvision import transforms
from transformers import ViTForImageClassification
from werkzeug.utils import secure_filename
# =====================
# Warnings
# =====================
# Intended to silence noisy Hugging Face warnings.
# NOTE(review): with no category/module filter this suppresses *every*
# Python warning raised in this process (torch, Flask, pymongo, ...), not
# only Hugging Face ones. Consider narrowing it.
warnings.filterwarnings(action="ignore")

# =====================
# Flask App
# =====================
app = Flask(__name__)
CORS(app)  # flask-cors with default settings applied to the whole app
app.config.update(UPLOAD_FOLDER="uploads")
# Make sure the upload directory exists; no error if it is already there.
os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
# =====================
# MongoDB
# =====================
# Connection string is read from the environment (e.g. a Space secret).
MONGO_URI = os.getenv("MONGO_URI")
if MONGO_URI is None:
    # MongoClient(None) quietly falls back to the default localhost server,
    # which normally does not exist inside a hosted container, so report
    # writes would fail or stall later. Surface the misconfiguration now.
    # (Logged via Flask's logger because all `warnings` are silenced above.)
    app.logger.warning(
        "MONGO_URI is not set; falling back to the default local MongoDB server"
    )
client = MongoClient(MONGO_URI)
db = client["skin-disease-db"]
reports = db["reports"]  # one document per /predict call
# =====================
# Labels (ORDER MUST MATCH TRAINING)
# =====================
# Position i in this list is the display name for output logit i of the
# classifier (predict() looks names up by argmax index), so the order has to
# be exactly the class order used when the checkpoint was trained.
# Never sort or reorder these entries.
labels = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]

# Size of the classifier head; passed as num_labels when building the model.
NUM_CLASSES = len(labels)

# All inference in this service runs on the CPU.
device = torch.device("cpu")
# =====================
# Load trained model
# =====================
# Fine-tuned checkpoint, downloaded from the Hugging Face Hub at import time.
weights_path = hf_hub_download(
    repo_id="pragun3669/dermify-vit",
    filename="best_vit1_model.pth",
)

# Build the ViT-Large architecture with a classifier head sized for our
# labels. ignore_mismatched_sizes=True lets from_pretrained swap the base
# checkpoint's head for a freshly initialised NUM_CLASSES-way head; that head
# is overwritten by the fine-tuned weights loaded next.
model = ViTForImageClassification.from_pretrained(
    "google/vit-large-patch16-224",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# Load the full fine-tuned state, INCLUDING the classifier head.
# NOTE(review): torch.load unpickles the file, which can run arbitrary code.
# The checkpoint comes from this project's own Hub repo; if that repo is ever
# not fully trusted, pass weights_only=True (torch >= 1.13).
state_dict = torch.load(weights_path, map_location=device)

# strict=False tolerates extra/absent keys, but it also silently accepts a
# checkpoint whose keys don't line up at all (e.g. a "module." prefix saved
# from DataParallel). That would leave the classifier randomly initialised
# and every prediction meaningless, so fail fast if the head did not load.
_load_result = model.load_state_dict(state_dict, strict=False)
_missing_head = [
    key for key in _load_result.missing_keys if key.startswith("classifier.")
]
if _missing_head:
    raise RuntimeError(
        f"Checkpoint {weights_path!r} did not provide classifier weights "
        f"{_missing_head}; refusing to serve a randomly initialised head."
    )
if _load_result.missing_keys or _load_result.unexpected_keys:
    # Non-fatal, but worth seeing in the logs (warnings are silenced above).
    app.logger.warning(
        "Checkpoint key mismatch: missing=%s unexpected=%s",
        _load_result.missing_keys,
        _load_result.unexpected_keys,
    )

model.to(device)
model.eval()  # inference mode for dropout etc.
# =====================
# Image Transform (MATCH TRAINING)
# =====================
# Preprocessing applied to each uploaded image before inference; it must stay
# identical to the pipeline used in training:
#   1. resize to 224x224 (the "224" in the google/vit-large-patch16-224 name),
#   2. PIL image -> float tensor (ToTensor scales 8-bit RGB to [0, 1]),
#   3. per-channel (x - 0.5) / 0.5, mapping [0, 1] onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# =====================
# Prediction Route
# =====================
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded skin image and record the result.

    Expects a multipart POST with the image under the ``file`` field.

    Returns:
        JSON ``{"prediction": <label>, "confidence": <percent>}``. Here
        ``confidence`` is the softmax probability of the top class, times
        100 and rounded to two decimals.

    Error responses (HTTP 400, JSON ``{"error": ...}``):
        * no ``file`` part in the request;
        * an empty filename;
        * an upload that PIL cannot decode as an image.

    Writing the report to MongoDB is best-effort. A database failure is
    logged, and the prediction is still returned to the client.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]
    if file.filename == "":
        return jsonify({"error": "Empty file"}), 400

    # Decode straight from the upload stream instead of saving to disk.
    # Saving under the client-supplied name would:
    #   * accumulate files forever;
    #   * let concurrent uploads with the same name overwrite each other;
    #   * break when secure_filename() reduces the name to "".
    # The `with` closes the PIL image once the RGB copy has been made.
    try:
        with Image.open(file.stream) as img:
            image = img.convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # can't identify, and OSError for truncated/corrupt image files.
        return jsonify({"error": "Invalid image file"}), 400

    tensor = transform(image).unsqueeze(0).to(device)  # add batch dim of 1
    with torch.no_grad():
        logits = model(tensor).logits

    probs = F.softmax(logits, dim=1)
    idx = probs.argmax(dim=1).item()
    prediction = labels[idx]
    confidence = round(probs[0, idx].item() * 100, 2)

    try:
        reports.insert_one({
            "prediction": prediction,
            "confidence": confidence,
            # Timezone-aware UTC; utcnow() is deprecated since Python 3.12.
            "createdAt": datetime.datetime.now(datetime.timezone.utc),
        })
    except PyMongoError:
        app.logger.exception("Failed to store prediction report")

    return jsonify({"prediction": prediction, "confidence": confidence})
# =====================
# Run (HF Spaces)
# =====================
if __name__ == "__main__":
    # Port 7860 is the port used for the Hugging Face Space. Binding to
    # 0.0.0.0 (all interfaces) makes the server reachable from outside the
    # container.
    app.run(port=7860, host="0.0.0.0")