File size: 4,639 Bytes
6113a36 a50b16e 6113a36 d1ea74e a50b16e 6113a36 fd7a884 6113a36 d1ea74e 6113a36 d1ea74e 6113a36 d1ea74e 6113a36 fd7a884 6113a36 8935f2a fd7a884 8935f2a d1ea74e f2d6004 d1ea74e 8935f2a d1ea74e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import io
from timeit import default_timer as timer
# ASGI application object; the /predict route below is registered on it.
app = FastAPI(title="SkinVision FastAPI")

# Classifier labels in output-index order: the model head is sized from
# len(class_names), and /predict pairs output probability i with class_names[i].
# NOTE(review): this order presumably matches the label indices used when
# pretrained_model.pth was trained — confirm before reordering or editing.
class_names = [
    "Acne",
    "Carcinoma",
    "Eczema",
    "Keratosis",
    "Milia",
    "Rosacea",
    "Clear",
]
# Build a ResNet-50 backbone with a len(class_names)-way head and load the
# fine-tuned checkpoint onto the CPU.
#
# The backbone is created with weights=None on purpose. The strict (default)
# load_state_dict below overwrites every parameter and buffer anyway, so
# downloading the ImageNet weights first only added startup time and a network
# dependency.
weights = ResNet50_Weights.DEFAULT  # kept as a module-level name; not used to build the model
model = resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# The path is resolved against the process's current working directory.
model.load_state_dict(torch.load("pretrained_model.pth", map_location=torch.device("cpu")))
model.eval()  # inference mode: BatchNorm layers use their running statistics
# Inference preprocessing, applied to each decoded RGB image:
#   1. resize to exactly 224x224 (a (h, w) tuple, so aspect ratio is not kept),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the standard ImageNet mean / std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Explanation dictionaries, keyed by predicted label. /predict looks up the top
# label with .get(), so each key should match an entry in class_names.

# Symptom descriptions, returned as "symptoms" by /predict.
symptoms_dict = {
    "Acne": "Symptoms include whiteheads, blackheads, pimples, nodules, and cysts on the face, chest, or back.",
    "Eczema": "Symptoms include dry, itchy, inflamed skin, often with red or brown patches and occasional oozing or crusting.",
    "Carcinoma": "Symptoms include new or unusual growths, sores that do not heal, or changes in the appearance of moles or skin lesions.",
    "Rosacea": "Rosacea causes facial redness, flushing, visible blood vessels, bumps, and pimples, along with burning, stinging, dryness, or swelling.",
    "Milia": "Symptoms include small, white, hard bumps on the face, particularly around the eyes and cheeks.",
    "Keratosis": "Symptoms include rough, scaly patches on the skin, often pink, red, or brown, with a sandpaper-like texture.",
    # User-facing text: fixed typo ("syptoms") and broken grammar.
    "Clear": "Congrats, your skin shows no symptoms of any disease.",
}
# Cause descriptions keyed by predicted label, returned as "causes" by /predict.
# NOTE(review): the Keratosis entry describes seborrheic keratosis, while the
# Keratosis symptom text (rough, scaly, sandpaper-like patches) reads like
# actinic keratosis — confirm which condition the model's "Keratosis" class covers.
causes_dict = {
    "Acne": "Acne occurs when clogged hair follicles trap sebum and bacteria. Hormonal changes, stress, greasy products, and diet can worsen it.",
    "Eczema": "Eczema is caused by an overactive immune system, genetics, and environmental triggers like soaps or allergens.",
    "Carcinoma": "Caused by genetic mutations often from UV radiation, toxins, or infections like HPV. Risk increases with exposure and genetics.",
    "Rosacea": "Exact cause is unknown; triggers include sun, stress, spicy food, and immune or blood vessel issues.",
    "Milia": "Caused by trapped dead skin cells forming cysts. Can also arise from skin damage or steroid cream overuse.",
    "Keratosis": "Seborrheic keratoses may be influenced by aging and genetics. They're harmless but can be removed for cosmetic reasons.",
    # User-facing text: fixed broken grammar ("no causes that Showing off any disease").
    "Clear": "Congrats, your skin shows no signs of any disease, so there are no causes to report.",
}
# Treatment guidance keyed by predicted label, returned as "treatments" by
# /predict. Keyword order below is the dict's insertion order.
treatments_dict = dict(
    Acne="Use benzoyl peroxide, salicylic acid, or retinoids. Severe cases may require oral antibiotics or isotretinoin.",
    Eczema="Moisturize regularly, avoid irritants, and use prescribed steroid creams or antihistamines during flare-ups.",
    Carcinoma="Treatments include surgery, chemotherapy, radiation, or targeted therapy. Consult a specialist urgently.",
    Rosacea="Treat with gentle skin care, antibiotics, or laser therapy. Avoid known personal triggers.",
    Milia="Often resolves naturally. Removal can be done by a dermatologist using topical retinoids or minor surgery.",
    Keratosis="Usually harmless. Treatment options include cryotherapy, laser therapy, or surgical removal.",
    Clear="Congrats, your skin looks clear!",
)
def _classify_image_bytes(image_bytes):
    """Decode ``image_bytes`` and return a ``{class_name: probability}`` dict.

    Synchronous and CPU-bound (PIL decode plus a ResNet forward pass), so
    async callers should run it in a worker thread, not on the event loop.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        outputs = model(input_tensor)
    probs = torch.softmax(outputs[0], dim=0)
    # Output index i corresponds to class_names[i].
    return dict(zip(class_names, probs.tolist()))


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin image and return the result as JSON.

    The response body contains:
      - ``prediction``: the top ``label`` and its ``confidence``, rounded to
        4 decimals;
      - ``all_probabilities``: the probability of every class;
      - ``symptoms``, ``causes``, ``treatments``: explanation text for the
        top label;
      - ``prediction_time_seconds``: elapsed time, rounded to 4 decimals.

    Errors are returned as ``{"error": message}``. The status is 400 when the
    upload cannot be decoded as an image and 500 for any other failure.
    """
    # Function-scope imports keep this endpoint self-contained.
    import logging

    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    try:
        start = timer()
        image_bytes = await file.read()
        # Decoding and inference are CPU-bound. Running them in the threadpool
        # keeps this async endpoint from blocking the event loop (and every
        # other in-flight request) while the model runs.
        predictions = await run_in_threadpool(_classify_image_bytes, image_bytes)
        top_label = max(predictions, key=predictions.get)
        result = {
            "prediction": {
                "label": top_label,
                "confidence": round(predictions[top_label], 4),
            },
            "all_probabilities": predictions,
            "symptoms": symptoms_dict.get(top_label),
            "causes": causes_dict.get(top_label),
            "treatments": treatments_dict.get(top_label),
            "prediction_time_seconds": round(timer() - start, 4),
        }
        return JSONResponse(result)
    except UnidentifiedImageError as e:
        # The client sent bytes that are not a decodable image. This is a
        # client error, not a server fault.
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:
        # Top-level request boundary: keep the traceback server-side and keep
        # the original error payload shape for API compatibility.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
|