"""SkinVision FastAPI service: classify an uploaded skin image with a fine-tuned ResNet-50."""

import io
import logging
from timeit import default_timer as timer

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

logger = logging.getLogger(__name__)

app = FastAPI(title="SkinVision FastAPI")

# Class names, indexed by the classifier's output logits.
# NOTE(review): assumes this order matches the label indices used when
# pretrained_model.pth was trained — confirm against the training script.
class_names = ['Acne', 'Carcinoma', 'Eczema', 'Keratosis', 'Milia', 'Rosacea', 'Clear']

MODEL_PATH = "pretrained_model.pth"


def _load_model(path: str = MODEL_PATH) -> torch.nn.Module:
    """Build a ResNet-50 with a ``len(class_names)``-way head and load fine-tuned weights.

    Args:
        path: Location of a state_dict saved from a model with this exact architecture.

    Returns:
        The model on CPU, switched to eval mode.
    """
    # weights=None: load_state_dict below is strict and overwrites every parameter
    # and buffer, so fetching the ImageNet checkpoint first would only cost startup
    # time and require network access.
    net = resnet50(weights=None)
    net.fc = torch.nn.Linear(net.fc.in_features, len(class_names))
    # weights_only=True limits unpickling to tensors and plain containers, which is
    # all a state_dict needs.
    state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = _load_model()

# Image transform: fixed 224x224 resize, then normalisation with the standard
# ImageNet channel mean/std (presumably the same preprocessing used in training).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Explanation dictionaries, keyed by class name.
symptoms_dict = {
    "Acne": "Symptoms include whiteheads, blackheads, pimples, nodules, and cysts on the face, chest, or back.",
    "Eczema": "Symptoms include dry, itchy, inflamed skin, often with red or brown patches and occasional oozing or crusting.",
    "Carcinoma": "Symptoms include new or unusual growths, sores that do not heal, or changes in the appearance of moles or skin lesions.",
    "Rosacea": "Rosacea causes facial redness, flushing, visible blood vessels, bumps, and pimples, along with burning, stinging, dryness, or swelling.",
    "Milia": "Symptoms include small, white, hard bumps on the face, particularly around the eyes and cheeks.",
    "Keratosis": "Symptoms include rough, scaly patches on the skin, often pink, red, or brown, with a sandpaper-like texture.",
    "Clear": "Congrats, your skin shows no symptoms of any disease.",
}

causes_dict = {
    "Acne": "Acne occurs when clogged hair follicles trap sebum and bacteria. Hormonal changes, stress, greasy products, and diet can worsen it.",
    "Eczema": "Eczema is caused by an overactive immune system, genetics, and environmental triggers like soaps or allergens.",
    "Carcinoma": "Caused by genetic mutations often from UV radiation, toxins, or infections like HPV. Risk increases with exposure and genetics.",
    "Rosacea": "Exact cause is unknown; triggers include sun, stress, spicy food, and immune or blood vessel issues.",
    "Milia": "Caused by trapped dead skin cells forming cysts. Can also arise from skin damage or steroid cream overuse.",
    "Keratosis": "Seborrheic keratoses may be influenced by aging and genetics. They're harmless but can be removed for cosmetic reasons.",
    "Clear": "Congrats, your skin shows no signs of an underlying disease.",
}

treatments_dict = {
    "Acne": "Use benzoyl peroxide, salicylic acid, or retinoids. Severe cases may require oral antibiotics or isotretinoin.",
    "Eczema": "Moisturize regularly, avoid irritants, and use prescribed steroid creams or antihistamines during flare-ups.",
    "Carcinoma": "Treatments include surgery, chemotherapy, radiation, or targeted therapy. Consult a specialist urgently.",
    "Rosacea": "Treat with gentle skin care, antibiotics, or laser therapy. Avoid known personal triggers.",
    "Milia": "Often resolves naturally. Removal can be done by a dermatologist using topical retinoids or minor surgery.",
    "Keratosis": "Usually harmless. Treatment options include cryotherapy, laser therapy, or surgical removal.",
    "Clear": "Congrats, your skin looks clear!",
}


def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if the bytes are empty, truncated, or not a recognised image format.
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's pixel limit.
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        # convert() returns a new, fully loaded image, so closing the source is safe.
        return img.convert("RGB")


def _classify(image: Image.Image) -> dict:
    """Run the model on *image* and return a ``{class_name: probability}`` mapping."""
    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        outputs = model(input_tensor)
    probs = torch.softmax(outputs[0], dim=0)
    return {name: float(p) for name, p in zip(class_names, probs.tolist())}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin image.

    Returns:
        JSON with ``prediction`` (top ``label`` and its ``confidence`` rounded to 4
        places), ``all_probabilities`` (every class), the ``symptoms`` / ``causes`` /
        ``treatments`` text for the top label, and ``prediction_time_seconds``.
        HTTP 400 with an ``error`` message if the upload cannot be decoded as an image;
        HTTP 500 with a generic ``error`` message on any other failure (details are logged).
    """
    try:
        start = timer()
        image_bytes = await file.read()
        # Decoding and inference are CPU-bound; run them in the worker threadpool so
        # they do not block the event loop for other requests.
        try:
            image = await run_in_threadpool(_decode_image, image_bytes)
        except (OSError, Image.DecompressionBombError) as exc:
            return JSONResponse(status_code=400, content={"error": f"Invalid image: {exc}"})
        predictions = await run_in_threadpool(_classify, image)
        top_label = max(predictions, key=predictions.get)
        result = {
            "prediction": {
                "label": top_label,
                "confidence": round(predictions[top_label], 4),
            },
            "all_probabilities": predictions,
            "symptoms": symptoms_dict.get(top_label),
            "causes": causes_dict.get(top_label),
            "treatments": treatments_dict.get(top_label),
            "prediction_time_seconds": round(timer() - start, 4),
        }
    except Exception:
        # Top-level request boundary: log the full traceback server-side, but do not
        # leak internal exception text to the client.
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(status_code=500, content={"error": "Prediction failed."})
    return JSONResponse(result)