# Spaces: Sleeping  (hosting-page status banner captured with the code; not part of the program)
import logging
import os

import gradio as gr
import joblib
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms
# Label maps: damage-class name <-> integer class index predicted by the model.
# The tuple order defines the index of each class.
_DAMAGE_CLASSES = ('G', 'DR', 'WD', 'ND', 'OTHER')
label_map = {name: index for index, name in enumerate(_DAMAGE_CLASSES)}
reverse_label_map = dict(enumerate(_DAMAGE_CLASSES))

# WARNING: adjust this based on the cluster-to-severity relation (see notes below).
# Keys are KMeans cluster ids; values are the severity labels shown to the user.
severity_map = dict([
    (0, 'Low'),     # e.g., yellow cluster
    (1, 'Medium'),  # e.g., pink cluster
    (2, 'High'),    # e.g., blue cluster
])
# Load model: a ResNet-18 whose final fully connected layer is replaced by a
# classification head with one output per entry in label_map.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(weights=None)  # architecture only; weights come from the checkpoint below
# Size the head from label_map so the two cannot drift apart if a class is added.
model.fc = nn.Linear(model.fc.in_features, len(label_map))
# NOTE(security): torch.load unpickles the checkpoint, so only load files you
# trust. Consider passing weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load("crop_damage_resnet18.pth", map_location=device))
model = model.to(device)
model.eval()  # switch to evaluation mode for inference
# Load the persisted PCA and KMeans objects used for the severity estimate.
# NOTE(review): joblib.load unpickles its input; only load trusted files.
pca, kmeans = (joblib.load(path) for path in ("pca.pkl", "kmeans.pkl"))
# Transform: resize the image to 224x224, then convert it to a tensor.
# NOTE(review): no normalization step is applied here; confirm this matches
# the preprocessing used when the model, PCA and KMeans were fitted.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# Prediction function
_logger = logging.getLogger(__name__)


def predict(img):
    """Classify the crop-damage type of an image and estimate its severity.

    The ResNet backbone is evaluated once. Its pooled features feed both the
    classification head (``model.fc``) and the PCA + KMeans severity pipeline,
    instead of running the backbone a second time for feature extraction.

    Args:
        img: A PIL image (the Gradio input is ``gr.Image(type="pil")``), or
            ``None`` when no image was supplied.

    Returns:
        A pair of strings ``("Damage: <class>", "Severity: <label>")``.
        On failure the first element is ``"Error: <message>"`` and the second
        is an empty string; exceptions are logged, not propagated.
    """
    if img is None:
        return "Error: No image provided.", ""

    try:
        # The network's first convolution expects 3 channels; normalise
        # grayscale / RGBA / palette images when predict is called directly.
        if img.mode != "RGB":
            img = img.convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)

        # Inference only: keep the whole pass out of autograd.
        with torch.no_grad():
            x = model.conv1(img_tensor)
            x = model.bn1(x)
            x = model.relu(x)
            x = model.maxpool(x)
            x = model.layer1(x)
            x = model.layer2(x)
            x = model.layer3(x)
            x = model.layer4(x)
            x = model.avgpool(x)
            features = torch.flatten(x, 1)
            # Applying the head to the pooled features reproduces model(img_tensor).
            logits = model.fc(features)

        # Predict damage class
        _, predicted = torch.max(logits, 1)
        damage_class = reverse_label_map[predicted.item()]

        # PCA + clustering on the pooled backbone features
        x_pca = pca.transform(features.cpu().numpy())
        # Cast the NumPy integer to a plain int for the dict lookup and messages.
        severity_cluster = int(kmeans.predict(x_pca)[0])
        print(f"Predicted Severity Cluster: {severity_cluster}")  # Debug log
        # Unknown cluster ids fall back to a generic "Cluster <id>" label.
        severity_label = severity_map.get(severity_cluster, f"Cluster {severity_cluster}")
        return f"Damage: {damage_class}", f"Severity: {severity_label}"
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback
        # in the server log instead of discarding it.
        _logger.exception("Prediction failed")
        return f"Error: {e}", ""
# Gradio interface: one PIL image in, two text outputs (damage, severity).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=["text", "text"],
    title="Crop Damage & Severity Detector",
    description="Upload a crop image to classify damage type and estimate severity.",
)
demo.launch()