Derma / app.py
sheikh987's picture
Update app.py
f2d6004 verified
Raw
History Blame Contribute Delete
4.64 kB
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import io
from timeit import default_timer as timer
# FastAPI application object; the /predict route is registered on it.
app = FastAPI(title="SkinVision FastAPI")

# Class names, in classifier output order: the probability at index i of the
# model output is reported under class_names[i].
class_names = ['Acne', 'Carcinoma', 'Eczema', 'Keratosis', 'Milia', 'Rosacea', 'Clear']

# Kept so the module still exposes ``weights``; referencing the enum member
# does not download anything by itself.
weights = ResNet50_Weights.DEFAULT

# Build the bare ResNet-50 architecture without fetching the ImageNet
# checkpoint: load_state_dict() below is strict, so every parameter and buffer
# is replaced by the fine-tuned values from pretrained_model.pth anyway.
# Skipping the download removes a network dependency at startup.
model = resnet50(weights=None)

# Replace the classification head so it emits one logit per skin class.
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))

# Load the fine-tuned weights onto the CPU regardless of where they were saved.
model.load_state_dict(torch.load("pretrained_model.pth", map_location=torch.device("cpu")))

# Inference only: switch dropout / batch-norm layers to evaluation behaviour.
model.eval()
# Image preprocessing applied to every upload before it reaches the model:
# resize to 224x224, convert to a float tensor, then normalize each channel.
# The mean/std values match those torchvision documents for its pretrained
# classification models.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])
# Explanation dictionaries
# Each maps a class label to user-facing text; the /predict handler looks the
# top predicted label up in them to build its response.

# Typical symptoms for each class label.
symptoms_dict = {
    "Acne": "Symptoms include whiteheads, blackheads, pimples, nodules, and cysts on the face, chest, or back.",
    "Eczema": "Symptoms include dry, itchy, inflamed skin, often with red or brown patches and occasional oozing or crusting.",
    "Carcinoma": "Symptoms include new or unusual growths, sores that do not heal, or changes in the appearance of moles or skin lesions.",
    "Rosacea": "Rosacea causes facial redness, flushing, visible blood vessels, bumps, and pimples, along with burning, stinging, dryness, or swelling.",
    "Milia": "Symptoms include small, white, hard bumps on the face, particularly around the eyes and cheeks.",
    "Keratosis": "Symptoms include rough, scaly patches on the skin, often pink, red, or brown, with a sandpaper-like texture.",
    # Fixed the misspelled, ungrammatical message shown for healthy skin.
    "Clear": "Congrats, your skin shows no symptoms of any disease.",
}
# Likely causes for each class label; the /predict handler looks the top
# predicted label up here to fill the "causes" field of its response.
causes_dict = {
    "Acne": "Acne occurs when clogged hair follicles trap sebum and bacteria. Hormonal changes, stress, greasy products, and diet can worsen it.",
    "Eczema": "Eczema is caused by an overactive immune system, genetics, and environmental triggers like soaps or allergens.",
    "Carcinoma": "Caused by genetic mutations often from UV radiation, toxins, or infections like HPV. Risk increases with exposure and genetics.",
    "Rosacea": "Exact cause is unknown; triggers include sun, stress, spicy food, and immune or blood vessel issues.",
    "Milia": "Caused by trapped dead skin cells forming cysts. Can also arise from skin damage or steroid cream overuse.",
    "Keratosis": "Seborrheic keratoses may be influenced by aging and genetics. They're harmless but can be removed for cosmetic reasons.",
    # Fixed the ungrammatical message shown for healthy skin.
    "Clear": "Congrats, your skin shows no signs of any disease, so there are no causes to report.",
}
# Suggested treatments for each class label; the /predict handler looks the
# top predicted label up here to fill the "treatments" field of its response.
# Long messages are wrapped using adjacent string literals, which Python joins
# into a single string.
treatments_dict = dict(
    Acne=(
        "Use benzoyl peroxide, salicylic acid, or retinoids. "
        "Severe cases may require oral antibiotics or isotretinoin."
    ),
    Eczema=(
        "Moisturize regularly, avoid irritants, and use prescribed "
        "steroid creams or antihistamines during flare-ups."
    ),
    Carcinoma=(
        "Treatments include surgery, chemotherapy, radiation, or targeted therapy. "
        "Consult a specialist urgently."
    ),
    Rosacea=(
        "Treat with gentle skin care, antibiotics, or laser therapy. "
        "Avoid known personal triggers."
    ),
    Milia=(
        "Often resolves naturally. "
        "Removal can be done by a dermatologist using topical retinoids or minor surgery."
    ),
    Keratosis=(
        "Usually harmless. "
        "Treatment options include cryotherapy, laser therapy, or surgical removal."
    ),
    Clear="Congrats, your skin looks clear!",
)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin image and return the prediction with explanations.

    Args:
        file: Uploaded image file. Its bytes are decoded with Pillow and
            converted to RGB before preprocessing.

    Returns:
        JSONResponse: On success, a body containing the top ``prediction``
        (label plus confidence rounded to 4 decimals), ``all_probabilities``
        for every class, the ``symptoms`` / ``causes`` / ``treatments`` text
        for the top label, and ``prediction_time_seconds``. An empty or
        undecodable upload yields HTTP 400; any other failure yields HTTP 500.
        Error bodies have the form ``{"error": <message>}``.
    """
    try:
        start = timer()
        image_bytes = await file.read()

        # A bad upload is the client's fault, so answer 400 rather than 500.
        if not image_bytes:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file is empty."},
            )
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Pillow reports unidentified, truncated or corrupt images as
            # OSError subclasses; oversized images raise DecompressionBombError.
            return JSONResponse(
                status_code=400,
                content={"error": f"Invalid image file: {e}"},
            )

        # Add a leading batch dimension: the model is called on a batch of one.
        input_tensor = transform(image).unsqueeze(0)

        # NOTE(review): inference runs synchronously inside this coroutine, so
        # the event loop is occupied until the forward pass completes.
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = torch.softmax(outputs[0], dim=0)

        # Pair each class name with its probability as a plain Python float.
        predictions = dict(zip(class_names, probs.tolist()))
        top_label = max(predictions, key=predictions.get)

        result = {
            "prediction": {
                "label": top_label,
                "confidence": round(predictions[top_label], 4),
            },
            "all_probabilities": predictions,
            "symptoms": symptoms_dict.get(top_label),
            "causes": causes_dict.get(top_label),
            "treatments": treatments_dict.get(top_label),
            "prediction_time_seconds": round(timer() - start, 4),
        }
        return JSONResponse(result)
    except Exception as e:
        # Top-level boundary: report unexpected failures as a JSON 500 instead
        # of letting the request crash.
        return JSONResponse(status_code=500, content={"error": str(e)})