File size: 4,639 Bytes
6113a36
a50b16e
 
 
6113a36
d1ea74e
a50b16e
6113a36
 
 
 
 
fd7a884
6113a36
d1ea74e
 
 
 
 
 
6113a36
d1ea74e
 
6113a36
 
 
 
 
d1ea74e
6113a36
 
 
 
 
 
 
fd7a884
6113a36
 
8935f2a
 
 
 
 
 
fd7a884
8935f2a
d1ea74e
 
 
 
 
 
 
f2d6004
d1ea74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8935f2a
d1ea74e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import io
import logging
from timeit import default_timer as timer

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

app = FastAPI(title="SkinVision FastAPI")

# Output classes. predict() maps class_names[i] to the model's i-th output,
# so this order is assumed to match the order used when the checkpoint was
# trained — TODO confirm against the training code.
class_names = ['Acne', 'Carcinoma', 'Eczema', 'Keratosis', 'Milia', 'Rosacea', 'Clear']

# Build a ResNet-50 with a len(class_names)-way head and load the fine-tuned
# checkpoint onto the CPU. weights=None skips fetching the ImageNet weights:
# load_state_dict (strict by default) overwrites every parameter and buffer
# anyway, so the final model is identical, but startup no longer needs network
# access or a populated torch hub cache.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True once the minimum supported torch allows it).
model = resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load("pretrained_model.pth", map_location=torch.device("cpu")))
model.eval()  # inference behaviour for BatchNorm layers

# Preprocessing applied to every uploaded image before inference:
# resize to a fixed 224x224, convert to a float tensor, then normalize each
# RGB channel with the standard ImageNet mean / standard deviation.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)

# Explanation text returned to clients alongside the predicted label.
# Each dictionary is keyed by the same labels as class_names; predict() looks
# them up with .get(), so a missing key would surface as null in the response.
symptoms_dict = {
    "Acne": "Symptoms include whiteheads, blackheads, pimples, nodules, and cysts on the face, chest, or back.",
    "Eczema": "Symptoms include dry, itchy, inflamed skin, often with red or brown patches and occasional oozing or crusting.",
    "Carcinoma": "Symptoms include new or unusual growths, sores that do not heal, or changes in the appearance of moles or skin lesions.",
    "Rosacea": "Rosacea causes facial redness, flushing, visible blood vessels, bumps, and pimples, along with burning, stinging, dryness, or swelling.",
    "Milia": "Symptoms include small, white, hard bumps on the face, particularly around the eyes and cheeks.",
    "Keratosis": "Symptoms include rough, scaly patches on the skin, often pink, red, or brown, with a sandpaper-like texture.",
    "Clear": "Congrats, your skin shows no symptoms of any disease.",
}
causes_dict = {
    "Acne": "Acne occurs when clogged hair follicles trap sebum and bacteria. Hormonal changes, stress, greasy products, and diet can worsen it.",
    "Eczema": "Eczema is caused by an overactive immune system, genetics, and environmental triggers like soaps or allergens.",
    "Carcinoma": "Caused by genetic mutations often from UV radiation, toxins, or infections like HPV. Risk increases with exposure and genetics.",
    "Rosacea": "Exact cause is unknown; triggers include sun, stress, spicy food, and immune or blood vessel issues.",
    "Milia": "Caused by trapped dead skin cells forming cysts. Can also arise from skin damage or steroid cream overuse.",
    "Keratosis": "Seborrheic keratoses may be influenced by aging and genetics. They're harmless but can be removed for cosmetic reasons.",
    "Clear": "Congrats, your skin shows no signs of any disease, so there is no cause to report.",
}
treatments_dict = {
    "Acne": "Use benzoyl peroxide, salicylic acid, or retinoids. Severe cases may require oral antibiotics or isotretinoin.",
    "Eczema": "Moisturize regularly, avoid irritants, and use prescribed steroid creams or antihistamines during flare-ups.",
    "Carcinoma": "Treatments include surgery, chemotherapy, radiation, or targeted therapy. Consult a specialist urgently.",
    "Rosacea": "Treat with gentle skin care, antibiotics, or laser therapy. Avoid known personal triggers.",
    "Milia": "Often resolves naturally. Removal can be done by a dermatologist using topical retinoids or minor surgery.",
    "Keratosis": "Usually harmless. Treatment options include cryotherapy, laser therapy, or surgical removal.",
    "Clear": "Congrats, your skin looks clear!",
}

def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if Pillow cannot identify or fully decode the data
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
        Image.DecompressionBombError: if the image exceeds Pillow's pixel limit.
    """
    # Image.open is lazy; .convert() forces the full decode here so corrupt or
    # truncated data fails inside this helper rather than later in the model.
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _classify(image):
    """Run the module-level model on an RGB image; return {class name: probability}."""
    input_tensor = transform(image).unsqueeze(0)  # batch of one
    with torch.no_grad():
        logits = model(input_tensor)
    probs = torch.softmax(logits[0], dim=0).tolist()
    return dict(zip(class_names, probs))


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin image.

    On success, returns JSON with the top label and its confidence (rounded to
    4 decimals), the probability for every class, the symptom/cause/treatment
    text for the top label, and the elapsed processing time in seconds.

    On failure, returns ``{"error": <message>}`` with status 400 if the
    upload cannot be decoded as an image, or status 500 for any other error.
    """
    start = timer()
    try:
        image_bytes = await file.read()
        try:
            # Decoding and inference are synchronous and CPU-bound; running them
            # in the threadpool keeps the event loop free for other requests.
            image = await run_in_threadpool(_decode_image, image_bytes)
        except (OSError, Image.DecompressionBombError) as exc:
            # An undecodable upload is a client error, not a server failure.
            return JSONResponse(status_code=400, content={"error": f"Invalid image file: {exc}"})

        predictions = await run_in_threadpool(_classify, image)
        top_label = max(predictions, key=predictions.get)

        result = {
            "prediction": {
                "label": top_label,
                "confidence": round(predictions[top_label], 4)
            },
            "all_probabilities": predictions,
            "symptoms": symptoms_dict.get(top_label),
            "causes": causes_dict.get(top_label),
            "treatments": treatments_dict.get(top_label),
            "prediction_time_seconds": round(timer() - start, 4)
        }
        return JSONResponse(result)

    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # preserve the existing 500 response shape for clients.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})