iamomtiwari commited on
Commit
e27fc27
·
verified ·
1 Parent(s): e505bbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -39
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  import torch
3
  from transformers import ViTForImageClassification, ViTFeatureExtractor
4
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
5
- from PIL import Image
6
 
7
  # Load models with error handling
8
  try:
@@ -11,9 +10,9 @@ try:
11
  fallback_model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
12
  fallback_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
13
  except Exception as e:
14
- raise gr.Error(f"Failed to load models: {str(e)}. Please check model names and internet connection.")
15
 
16
- # Define class labels with treatment advice
17
  class_labels = {
18
  1: {"label": "Stage Corn Common Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
19
  2: {"label": "Stage Corn Gray Leaf Spot", "treatment": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed."},
@@ -31,16 +30,15 @@ class_labels = {
31
  14: {"label": "Stage Wheat Yellow Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
32
  }
33
 
34
- # Create 0-indexed labels list
35
  labels_list = [class_labels[i]["label"] for i in range(1, 15)]
36
-
37
- # Confidence threshold
38
  CONFIDENCE_THRESHOLD = 0.5
39
 
40
  def predict(image):
41
  try:
42
- # First, try ViT model
43
- inputs = feature_extractor(images=image, return_tensors="pt")
 
 
44
  with torch.no_grad():
45
  outputs = model(**inputs)
46
  logits = outputs.logits
@@ -49,43 +47,57 @@ def predict(image):
49
  confidence = confidences[0, predicted_class_idx].item()
50
 
51
  if confidence < CONFIDENCE_THRESHOLD:
52
- # Fallback to ResNet-50
53
  inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
54
  with torch.no_grad():
55
  outputs_fallback = fallback_model(**inputs_fallback)
56
- logits_fallback = outputs_fallback.logits
57
- confidences_fallback = torch.softmax(logits_fallback, dim=-1)
58
- predicted_class_idx_fallback = logits_fallback.argmax(-1).item()
59
- fallback_confidence = confidences_fallback[0, predicted_class_idx_fallback].item()
60
-
61
- fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
62
- return (
63
- f"Low confidence in ViT model ({confidence * 100:.2f}%).\n"
64
- f"ResNet-50 predicts: {fallback_label} ({fallback_confidence * 100:.2f}%).\n\n"
65
- "If this doesn't match your input, try another image."
66
- )
67
 
68
  predicted_label = labels_list[predicted_class_idx]
69
- treatment_advice = class_labels[predicted_class_idx + 1]["treatment"]
70
- return (
71
- f"Disease: {predicted_label} ({confidence * 100:.2f}%)\n\n"
72
- f"Treatment Advice: {treatment_advice}"
73
- )
74
 
75
  except Exception as e:
76
- return f"Error processing image: {str(e)}. Please try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Gradio Interface
79
- demo = gr.Interface(
80
- fn=predict,
81
- inputs=gr.Image(type="pil"),
82
- outputs="text",
83
- title="🌱 Crop Disease Detection",
84
- description="Upload a crop plant image to detect diseases. Uses ViT + ResNet-50 fallback.",
85
- examples=[
86
- ["corn_rust_example.jpg"], # Replace with real examples
87
- ["wheat_healthy_example.jpg"]
88
- ]
89
- )
90
 
91
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from transformers import ViTForImageClassification, ViTFeatureExtractor
4
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
 
5
 
6
  # Load models with error handling
7
  try:
 
10
  fallback_model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
11
  fallback_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
12
  except Exception as e:
13
+ raise gr.Error(f"Model loading failed: {str(e)}")
14
 
15
+ # Class labels and treatments (truncated for brevity)
16
  class_labels = {
17
  1: {"label": "Stage Corn Common Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
18
  2: {"label": "Stage Corn Gray Leaf Spot", "treatment": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed."},
 
30
  14: {"label": "Stage Wheat Yellow Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
31
  }
32
 
 
33
  labels_list = [class_labels[i]["label"] for i in range(1, 15)]
 
 
34
  CONFIDENCE_THRESHOLD = 0.5
35
 
36
  def predict(image):
37
  try:
38
+ if not isinstance(image, torch.Tensor):
39
+ # Convert image to tensor if needed
40
+ inputs = feature_extractor(images=image, return_tensors="pt")
41
+
42
  with torch.no_grad():
43
  outputs = model(**inputs)
44
  logits = outputs.logits
 
47
  confidence = confidences[0, predicted_class_idx].item()
48
 
49
  if confidence < CONFIDENCE_THRESHOLD:
 
50
  inputs_fallback = fallback_feature_extractor(images=image, return_tensors="pt")
51
  with torch.no_grad():
52
  outputs_fallback = fallback_model(**inputs_fallback)
53
+ predicted_class_idx_fallback = outputs_fallback.logits.argmax(-1).item()
54
+ fallback_label = fallback_model.config.id2label[predicted_class_idx_fallback]
55
+ return f"Low confidence ({confidence:.2%}). Fallback prediction: {fallback_label}"
 
 
 
 
 
 
 
 
56
 
57
  predicted_label = labels_list[predicted_class_idx]
58
+ treatment = class_labels[predicted_class_idx + 1]["treatment"]
59
+ return f"Disease: {predicted_label}\n\nTreatment: {treatment}"
 
 
 
60
 
61
  except Exception as e:
62
+ return f"Error: {str(e)}"
63
+
64
+ # Create interface with explicit types
65
+ with gr.Blocks() as demo:
66
+ gr.Markdown("# 🌱 Crop Disease Detection")
67
+ gr.Markdown("Upload a crop plant image to detect diseases")
68
+
69
+ with gr.Row():
70
+ image_input = gr.Image(type="pil")
71
+ output_text = gr.Textbox()
72
+
73
+ submit_btn = gr.Button("Analyze")
74
+ submit_btn.click(
75
+ fn=predict,
76
+ inputs=image_input,
77
+ outputs=output_text
78
+ )
79
 
80
+ gr.Examples(
81
+ examples=[["example_corn.jpg"], ["example_wheat.jpg"]],
82
+ inputs=image_input
83
+ )
 
 
 
 
 
 
 
 
84
 
85
+ if __name__ == "__main__":
86
+ demo.launch()
87
+ """# Define class labels with treatment advice
88
+ class_labels = {
89
+ 1: {"label": "Stage Corn Common Rust", "treatment": "Apply fungicides as soon as symptoms are noticed. Practice crop rotation and remove infected plants."},
90
+ 2: {"label": "Stage Corn Gray Leaf Spot", "treatment": "Rotate crops to non-host plants, apply resistant varieties, and use fungicides as needed."},
91
+ 3: {"label": "Stage Safe Corn Healthy", "treatment": "Continue good agricultural practices: ensure proper irrigation, nutrient supply, and monitor for pests."},
92
+ 4: {"label": "Stage Corn Northern Leaf Blight", "treatment": "Remove and destroy infected plant debris, apply fungicides, and rotate crops."},
93
+ 5: {"label": "Stage Rice Brown Spot", "treatment": "Use resistant varieties, improve field drainage, and apply fungicides if necessary."},
94
+ 6: {"label": "Stage Safe Rice Healthy", "treatment": "Maintain proper irrigation, fertilization, and pest control measures."},
95
+ 7: {"label": "Stage Rice Leaf Blast", "treatment": "Use resistant varieties, apply fungicides during high-risk periods, and practice good field management."},
96
+ 8: {"label": "Stage Rice Neck Blast", "treatment": "Plant resistant varieties, improve nutrient management, and apply fungicides if symptoms appear."},
97
+ 9: {"label": "Stage Sugarcane Bacterial Blight", "treatment": "Use disease-free planting material, practice crop rotation, and destroy infected plants."},
98
+ 10: {"label": "Stage Safe Sugarcane Healthy", "treatment": "Maintain healthy soil conditions and proper irrigation."},
99
+ 11: {"label": "Stage Sugarcane Red Rot", "treatment": "Plant resistant varieties and ensure good drainage."},
100
+ 12: {"label": "Stage Wheat Brown Rust", "treatment": "Apply fungicides and practice crop rotation with non-host crops."},
101
+ 13: {"label": "Stage Safe Wheat Healthy", "treatment": "Continue with good management practices, including proper fertilization and weed control."},
102
+ 14: {"label": "Stage Wheat Yellow Rust", "treatment": "Use resistant varieties, apply fungicides, and rotate crops."}
103
+ }"""