"""FastAPI service that classifies plant diseases from uploaded images.

At import time the ConvNeXt-Tiny weights and the class-name list are
downloaded from the Hugging Face Hub repo ``REPO_ID``. ``POST /predict``
then returns the ranked predictions for an uploaded image.
"""

import io
import json

import timm
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

app = FastAPI()

# CORS so the React frontend can call the API.
# NOTE(review): allow_origins=["*"] accepts any origin; restrict it in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model loading
REPO_ID = "abdallah110/plant-disease-model"
device = torch.device("cpu")

# Maximum number of ranked predictions returned by /predict.
TOP_K = 5

print("⏳ Loading model...")
model_path = hf_hub_download(repo_id=REPO_ID, filename="final_model.pth")
class_names_path = hf_hub_download(repo_id=REPO_ID, filename="class_names.json")

with open(class_names_path, encoding="utf-8") as f:
    class_names = json.load(f)

model = timm.create_model("convnext_tiny", pretrained=False, num_classes=len(class_names))
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code when it is loaded. The file is
# only used via load_state_dict, so a state dict is all that is needed.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()
print("✅ Model loaded!")

# Preprocessing: resize to 224x224, convert to a tensor, then normalize with
# the ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


@app.get("/")
def root():
    """Health check: report that the API is up."""
    return {"status": "ok", "message": "Plant Disease API is running 🌿"}


def _decode_image(contents):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image.
    """
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        # Image.open is lazy; convert() forces the actual decode, so both
        # calls must sit inside the try. UnidentifiedImageError subclasses OSError.
        return Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _predict_top_k(img, k=TOP_K):
    """Run the classifier on *img* and return up to *k* ranked predictions.

    Each entry is ``{"class": name, "confidence": percent}``, with the percent
    rounded to 2 decimals. Entries are ordered by descending probability.
    """
    tensor = transform(img).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        outputs = model(tensor)
    probs = torch.softmax(outputs, dim=1)[0]
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    top = torch.topk(probs, min(k, probs.numel()))
    return [
        {"class": class_names[idx.item()], "confidence": round(prob.item() * 100, 2)}
        for prob, idx in zip(top.values, top.indices)
    ]


@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    This is a plain ``def`` so FastAPI runs it in its threadpool. Model
    inference is CPU-bound and would otherwise block the event loop, stalling
    every other request.

    Returns:
        ``{"prediction": best class, "confidence": its percent,
        "top5": ranked list}``. "top5" holds at most TOP_K entries and fewer
        if the model has fewer classes.

    Raises:
        HTTPException: 400 if the upload is empty or not a valid image.
    """
    contents = file.file.read()
    img = _decode_image(contents)
    results = _predict_top_k(img)
    return {
        "prediction": results[0]["class"],
        "confidence": results[0]["confidence"],
        "top5": results,
    }