Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,8 +4,6 @@ import torch.nn as nn
|
|
| 4 |
from torchvision import transforms
|
| 5 |
from torchvision.models import swin_t
|
| 6 |
from PIL import Image
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
import io
|
| 9 |
|
| 10 |
# π§ Model definition
|
| 11 |
class MMIM(nn.Module):
|
|
@@ -31,7 +29,7 @@ model.load_state_dict(torch.load("MMIM_best.pth", map_location=device))
|
|
| 31 |
model.to(device)
|
| 32 |
model.eval()
|
| 33 |
|
| 34 |
-
# β
|
| 35 |
class_names = [
|
| 36 |
"Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
|
| 37 |
"Prickly acacia", "Rubber vine", "Siam weed", "Snake weed"
|
|
@@ -43,57 +41,32 @@ transform = transforms.Compose([
|
|
| 43 |
transforms.ToTensor()
|
| 44 |
])
|
| 45 |
|
| 46 |
-
# π Prediction with
|
| 47 |
def predict(img):
|
| 48 |
img = img.convert('RGB')
|
| 49 |
img_tensor = transform(img).unsqueeze(0).to(device)
|
| 50 |
|
| 51 |
with torch.no_grad():
|
| 52 |
outputs = model(img_tensor)
|
| 53 |
-
probs = torch.softmax(outputs, dim=1)
|
|
|
|
| 54 |
|
| 55 |
-
conf, pred = torch.max(probs, 0)
|
| 56 |
predicted_class = class_names[pred.item()]
|
| 57 |
confidence = conf.item() * 100
|
| 58 |
|
| 59 |
-
# β
Main prediction message
|
| 60 |
if predicted_class.lower() == "negative":
|
| 61 |
-
|
| 62 |
-
else:
|
| 63 |
-
message = f"β
Predicted class: **{predicted_class}**\nConfidence: **{confidence:.2f}%**"
|
| 64 |
|
| 65 |
-
|
| 66 |
-
top_probs, top_idxs = torch.topk(probs, 3)
|
| 67 |
-
top_classes = [class_names[i] for i in top_idxs]
|
| 68 |
-
top_confidences = [p.item() * 100 for p in top_probs]
|
| 69 |
|
| 70 |
-
|
| 71 |
-
ax.barh(top_classes[::-1], top_confidences[::-1], color='skyblue')
|
| 72 |
-
ax.set_xlim(0, 100)
|
| 73 |
-
ax.set_xlabel("Confidence (%)")
|
| 74 |
-
ax.set_title("Top 3 Predictions")
|
| 75 |
-
plt.tight_layout()
|
| 76 |
-
|
| 77 |
-
buf = io.BytesIO()
|
| 78 |
-
plt.savefig(buf, format="png")
|
| 79 |
-
plt.close(fig)
|
| 80 |
-
buf.seek(0)
|
| 81 |
-
|
| 82 |
-
from PIL import Image as PILImage
|
| 83 |
-
chart_img = PILImage.open(buf)
|
| 84 |
-
|
| 85 |
-
return message, chart_img
|
| 86 |
-
|
| 87 |
-
# π¨ Gradio UI
|
| 88 |
interface = gr.Interface(
|
| 89 |
fn=predict,
|
| 90 |
inputs=gr.Image(type="pil"),
|
| 91 |
-
outputs=
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
],
|
| 95 |
-
title="Weed Image Classifier - Debug Mode",
|
| 96 |
-
description="π§ͺ Debug view enabled: see prediction confidence and top 3 predicted classes."
|
| 97 |
)
|
| 98 |
|
| 99 |
interface.launch()
|
|
|
|
|
|
| 4 |
from torchvision import transforms
|
| 5 |
from torchvision.models import swin_t
|
| 6 |
from PIL import Image
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# π§ Model definition
|
| 9 |
class MMIM(nn.Module):
|
|
|
|
| 29 |
model.to(device)
|
| 30 |
model.eval()
|
| 31 |
|
| 32 |
+
# β
Updated class names (match folder structure)
|
| 33 |
class_names = [
|
| 34 |
"Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
|
| 35 |
"Prickly acacia", "Rubber vine", "Siam weed", "Snake weed"
|
|
|
|
| 41 |
transforms.ToTensor()
|
| 42 |
])
|
| 43 |
|
| 44 |
+
# π Prediction function with negative detection
|
| 45 |
def predict(img):
|
| 46 |
img = img.convert('RGB')
|
| 47 |
img_tensor = transform(img).unsqueeze(0).to(device)
|
| 48 |
|
| 49 |
with torch.no_grad():
|
| 50 |
outputs = model(img_tensor)
|
| 51 |
+
probs = torch.softmax(outputs, dim=1)
|
| 52 |
+
conf, pred = torch.max(probs, 1)
|
| 53 |
|
|
|
|
| 54 |
predicted_class = class_names[pred.item()]
|
| 55 |
confidence = conf.item() * 100
|
| 56 |
|
|
|
|
| 57 |
if predicted_class.lower() == "negative":
|
| 58 |
+
return f"β οΈ This image is predicted as Negative.\nConfidence: {confidence:.2f}%"
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
return f"β
Predicted as a weed with class-{predicted_class}\nConfidence: {confidence:.2f}%"
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
# π¨ Gradio Interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
interface = gr.Interface(
|
| 64 |
fn=predict,
|
| 65 |
inputs=gr.Image(type="pil"),
|
| 66 |
+
outputs="text",
|
| 67 |
+
title="Weed Image Classifier",
|
| 68 |
+
description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
|
|
|
|
|
|
|
|
|
|
| 69 |
)
|
| 70 |
|
| 71 |
interface.launch()
|
| 72 |
+
remove the confidence level
|