import io
import logging
from timeit import default_timer as timer

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision.models import ResNet50_Weights, resnet50
|
|
# ASGI application object; route handlers such as /predict below are
# registered on it via its decorators.
app = FastAPI(
    title="SkinVision FastAPI",
)
|
|
# Output labels of the classifier. The /predict handler pairs the i-th softmax
# probability with class_names[i], and the model head is sized from
# len(class_names), so this order must match the model's output units
# (presumably the training label order -- verify before editing).
class_names = [
    "Acne",
    "Carcinoma",
    "Eczema",
    "Keratosis",
    "Milia",
    "Rosacea",
    "Clear",
]
|
|
# Classifier: ResNet-50 with a classification head sized to ``class_names``,
# populated from the fine-tuned checkpoint ``pretrained_model.pth``.
#
# ImageNet weights are intentionally not requested (``weights=None``):
# ``load_state_dict`` below runs in strict mode and overwrites every parameter
# and buffer, so fetching them would only add a download (or cache read) to
# start-up. ``num_classes`` builds the same 2048 -> len(class_names) ``fc``
# layer that the checkpoint's ``fc.weight``/``fc.bias`` are loaded into.
#
# ``weights`` stays bound so existing references to this module-level name
# keep working; it is not used to initialise the network.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=None, num_classes=len(class_names))

# NOTE(security): torch.load unpickles the file, so only point it at a trusted
# checkpoint (on torch >= 1.13, ``weights_only=True`` restricts unpickling).
# Loaded inline so the temporary state dict is not kept alive at module level.
model.load_state_dict(
    torch.load("pretrained_model.pth", map_location=torch.device("cpu"))
)

# Put layers such as BatchNorm into inference behaviour (running statistics).
model.eval()
|
|
# Preprocessing applied to every uploaded RGB image before inference.
# Presumably this mirrors the preprocessing used when training
# pretrained_model.pth -- keep the two in sync.
transform = transforms.Compose(
    [
        # Fixed (height, width); the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation with the standard ImageNet mean/std.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
# User-facing information returned by /predict for the top predicted label.
# Each dict is keyed by the labels in ``class_names`` (listed here in the same
# order) and read with ``.get``, so a missing key would surface as ``None``.
#
# NOTE(review): the "Keratosis" symptoms text reads like actinic keratosis,
# while the causes text names seborrheic keratoses and the treatment calls it
# "usually harmless". Confirm with a clinician which condition the model was
# trained on and align the three entries.
symptoms_dict = {
    "Acne": "Symptoms include whiteheads, blackheads, pimples, nodules, and cysts on the face, chest, or back.",
    "Carcinoma": "Symptoms include new or unusual growths, sores that do not heal, or changes in the appearance of moles or skin lesions.",
    "Eczema": "Symptoms include dry, itchy, inflamed skin, often with red or brown patches and occasional oozing or crusting.",
    "Keratosis": "Symptoms include rough, scaly patches on the skin, often pink, red, or brown, with a sandpaper-like texture.",
    "Milia": "Symptoms include small, white, hard bumps on the face, particularly around the eyes and cheeks.",
    "Rosacea": "Rosacea causes facial redness, flushing, visible blood vessels, bumps, and pimples, along with burning, stinging, dryness, or swelling.",
    "Clear": "Congrats, your skin shows no symptoms of the conditions this model checks for.",
}

causes_dict = {
    "Acne": "Acne occurs when clogged hair follicles trap sebum and bacteria. Hormonal changes, stress, greasy products, and diet can worsen it.",
    "Carcinoma": "Caused by genetic mutations often from UV radiation, toxins, or infections like HPV. Risk increases with exposure and genetics.",
    "Eczema": "Eczema is caused by an overactive immune system, genetics, and environmental triggers like soaps or allergens.",
    "Keratosis": "Seborrheic keratoses may be influenced by aging and genetics. They're harmless but can be removed for cosmetic reasons.",
    "Milia": "Caused by trapped dead skin cells forming cysts. Can also arise from skin damage or steroid cream overuse.",
    "Rosacea": "Exact cause is unknown; triggers include sun, stress, spicy food, and immune or blood vessel issues.",
    "Clear": "Congrats, your skin shows no signs of the conditions this model checks for, so there are no causes to report.",
}

treatments_dict = {
    "Acne": "Use benzoyl peroxide, salicylic acid, or retinoids. Severe cases may require oral antibiotics or isotretinoin.",
    "Carcinoma": "Treatments include surgery, chemotherapy, radiation, or targeted therapy. Consult a specialist urgently.",
    "Eczema": "Moisturize regularly, avoid irritants, and use prescribed steroid creams or antihistamines during flare-ups.",
    "Keratosis": "Usually harmless. Treatment options include cryotherapy, laser therapy, or surgical removal.",
    "Milia": "Often resolves naturally. Removal can be done by a dermatologist using topical retinoids or minor surgery.",
    "Rosacea": "Treat with gentle skin care, antibiotics, or laser therapy. Avoid known personal triggers.",
    "Clear": "Congrats, your skin looks clear!",
}
|
|
logger = logging.getLogger(__name__)


def _decode_image(image_bytes):
    """Decode an uploaded file's raw bytes into an RGB PIL image.

    Raises:
        OSError: Pillow cannot identify or decode the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
        ValueError: Pillow rejects the data as invalid.
        Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    # convert() reads the pixel data, so truncated files also fail here.
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _classify(image):
    """Run the model on one RGB image and return ``{class name: probability}``.

    Probabilities are the softmax of the model's logits, paired with
    ``class_names`` by index.
    """
    batch = transform(image).unsqueeze(0)  # add a batch dimension of size 1
    # Grad mode is thread-local, so it must be disabled in the thread that
    # runs the forward pass (this function runs in a worker thread).
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits[0], dim=0).tolist()
    return dict(zip(class_names, probs))


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin image.

    On success returns JSON with the top ``prediction`` (``label`` and
    ``confidence`` rounded to 4 places), ``all_probabilities`` for every
    class, the matching ``symptoms`` / ``causes`` / ``treatments`` text, and
    ``prediction_time_seconds`` measured from the start of the request.

    Error responses carry ``{"error": <message>}``:
      * 400 when the upload cannot be decoded as an image;
      * 500 for any other failure (the traceback is logged, not returned).
    """
    start = timer()
    try:
        image_bytes = await file.read()
        # Decoding and ResNet inference are CPU-bound; running them in the
        # threadpool keeps this async handler from blocking the event loop.
        try:
            image = await run_in_threadpool(_decode_image, image_bytes)
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # Bad input is the client's problem, not a server error.
            logger.info("Rejected upload %r: %s", file.filename, exc)
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file is not a valid image."},
            )
        predictions = await run_in_threadpool(_classify, image)
    except Exception:
        # Request boundary: log the details but do not leak internals.
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(
            status_code=500,
            content={"error": "Prediction failed."},
        )

    top_label = max(predictions, key=predictions.get)
    result = {
        "prediction": {
            "label": top_label,
            "confidence": round(predictions[top_label], 4),
        },
        "all_probabilities": predictions,
        "symptoms": symptoms_dict.get(top_label),
        "causes": causes_dict.get(top_label),
        "treatments": treatments_dict.get(top_label),
        "prediction_time_seconds": round(timer() - start, 4),
    }
    return JSONResponse(result)
|
|