# CNNN / app.py — uploaded by abdallah110 ("Upload app.py", commit 08d4318, verified)
import io
import json

import timm
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
app = FastAPI()

# Enable CORS so the React frontend can call this API from a different origin.
# Every origin, HTTP method and request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Load the classifier and its class names from the Hugging Face Hub at import time.
REPO_ID = "abdallah110/plant-disease-model"
device = torch.device("cpu")
print("โณ Loading model...")
model_path = hf_hub_download(repo_id=REPO_ID, filename="final_model.pth")
class_names_path = hf_hub_download(repo_id=REPO_ID, filename="class_names.json")
# JSON is UTF-8 by spec; don't depend on the platform's default locale encoding.
with open(class_names_path, encoding="utf-8") as f:
    # Presumably a list ordered by model output index (it is indexed by
    # class id at prediction time) — confirm against the training export.
    class_names = json.load(f)
# One output logit per class name; the architecture must match the checkpoint
# because load_state_dict is strict by default.
model = timm.create_model("convnext_tiny", pretrained=False, num_classes=len(class_names))
# torch.load unpickles its input. weights_only=True restricts it to tensors and
# plain containers (all a state_dict needs), so a tampered checkpoint fetched
# over the network cannot execute arbitrary code during loading.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()
print("โœ… Model loaded!")
# Preprocessing applied to every uploaded image before inference:
# resize to a fixed 224x224, convert to a float tensor, then normalize
# each channel with the given mean/std.
# NOTE(review): these mean/std values are the common ImageNet statistics;
# presumably they match the training pipeline — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
@app.get("/")
def root():
    # Liveness check: always answers with the same fixed payload.
    status_payload = {
        "status": "ok",
        "message": "Plant Disease API is running ๐ŸŒฟ",
    }
    return status_payload
def _classify_image_bytes(contents: bytes) -> dict:
    """Decode raw image bytes, run the classifier and build the response body.

    Kept synchronous so the async endpoint can offload it to a worker thread:
    image decoding and the CPU forward pass would otherwise block the event
    loop and stall all other requests while one image is processed.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        ``{"prediction": ..., "confidence": float, "top5": [...]}`` where each
        ``top5`` entry is ``{"class": <class name>, "confidence": float}``,
        confidences are percentages rounded to 2 decimals, ordered highest
        first, and at most 5 entries are returned.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.
    """
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for non-image
        # or empty data and OSError for truncated files; these are client
        # errors, so report 400 instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(tensor)
    probs = torch.softmax(outputs, dim=1)[0]
    # torch.topk raises if k exceeds the number of classes, so clamp k.
    top = torch.topk(probs, min(5, probs.numel()))
    results = [
        {"class": class_names[idx], "confidence": round(prob * 100, 2)}
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    ]
    return {
        "prediction": results[0]["class"],
        "confidence": results[0]["confidence"],
        "top5": results,
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top predictions."""
    contents = await file.read()
    # Offload CPU-bound decoding + inference so the event loop stays responsive.
    return await run_in_threadpool(_classify_image_bytes, contents)