# Cropdamage / app.py
# Venisri2006 — "Update app.py" (commit 974baaa, verified)
import logging
import os

import gradio as gr
import joblib
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
# ✅ Label maps
# Damage-class code -> class index (position in the classifier's output logits).
label_map = {"G": 0, "DR": 1, "WD": 2, "ND": 3, "OTHER": 4}
# Inverse lookup used to decode the arg-max logit index back to a class code.
reverse_label_map = dict(zip(label_map.values(), label_map.keys()))
# ⚠️ KMeans cluster id -> human-readable severity label.
# KMeans cluster ids carry no inherent order, so this mapping has to be
# re-checked whenever kmeans.pkl is refitted.
# NOTE(review): earlier comments associated clusters 0/1/2 with
# yellow/pink/blue clusters from some offline plot — confirm the
# Low/Medium/High ordering against that analysis.
severity_map = dict(enumerate(("Low", "Medium", "High")))
# ✅ Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ResNet-18 built without pretrained weights (everything comes from the
# checkpoint below), with the final fc layer resized to one logit per damage
# class. Sizing it from label_map keeps the head and the label decoding used
# by predict() in sync; a mismatching checkpoint now fails loudly in
# load_state_dict instead of silently mislabeling classes.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(label_map))
# weights_only=True: the checkpoint is a plain state_dict, so there is no
# reason to allow unpickling of arbitrary objects (also avoids the
# FutureWarning torch emits when weights_only is left unspecified).
model.load_state_dict(
    torch.load("crop_damage_resnet18.pth", map_location=device, weights_only=True)
)
model = model.to(device)
# Inference mode: BatchNorm layers use their running statistics.
model.eval()
# ✅ Load PCA and KMeans
# Fitted dimensionality reducer and clusterer applied to pooled CNN features
# in predict() to derive a severity cluster.
# NOTE(review): joblib.load unpickles its input — only ship/load these files
# from a trusted source.
pca, kmeans = joblib.load("pca.pkl"), joblib.load("kmeans.pkl")
# ✅ Transform
# Resize to 224x224, then convert to a float tensor in [0, 1] (C x H x W).
# No mean/std normalization is applied — presumably matching how the
# checkpoint was trained; confirm before changing.
_preprocess_steps = [transforms.Resize((224, 224)), transforms.ToTensor()]
transform = transforms.Compose(_preprocess_steps)
# ✅ Prediction function
_logger = logging.getLogger(__name__)


def _extract_features(img_tensor):
    """Return pooled, flattened ResNet-18 backbone features for a batch.

    Replays the layers of torchvision's ResNet forward pass up to and
    including global average pooling + flatten, i.e. everything except the
    final ``fc`` layer, so ``model.fc(_extract_features(t))`` produces the
    same logits as ``model(t)``. Callers should run this under
    ``torch.no_grad()`` for inference.
    """
    x = model.conv1(img_tensor)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    x = model.avgpool(x)
    return torch.flatten(x, 1)


def predict(img):
    """Classify the damage type of a crop image and estimate its severity.

    Args:
        img: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A ``(damage_text, severity_text)`` pair of strings for the two Gradio
        text outputs. On failure the first element is an ``"Error: ..."``
        message and the second is an empty string.
    """
    if img is None:
        return "Error: no image provided", ""
    try:
        # The network has 3 input channels; convert RGBA / grayscale / palette
        # images so conv1 does not fail when predict() is called outside
        # Gradio's own RGB preprocessing.
        if img.mode != "RGB":
            img = img.convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)

        # 🔍 One backbone pass feeds both the classifier head and the
        # severity clustering; no_grad skips building an unused autograd graph.
        with torch.no_grad():
            features = _extract_features(img_tensor)
            output = model.fc(features)
        _, predicted = torch.max(output, 1)
        damage_class = reverse_label_map[predicted.item()]

        # 🔍 Severity: project the pooled features with the fitted PCA and
        # assign them to a KMeans cluster.
        x_pca = pca.transform(features.cpu().numpy())
        severity_cluster = int(kmeans.predict(x_pca)[0])
        _logger.debug("Predicted severity cluster: %s", severity_cluster)
        severity_label = severity_map.get(
            severity_cluster, f"Cluster {severity_cluster}"
        )
        return f"Damage: {damage_class}", f"Severity: {severity_label}"
    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback.
        _logger.exception("Prediction failed")
        return f"Error: {e}", ""
# ✅ Gradio Interface
# One image input, two text outputs matching predict()'s
# (damage, severity) return pair.
# NOTE(review): the app launches at import time; consider an
# `if __name__ == "__main__":` guard if this module is ever imported.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=["text", "text"],
    title="Crop Damage & Severity Detector",
    description="Upload a crop image to classify damage type and estimate severity.",
)
demo.launch()