aditya-sah commited on
Commit
1fe5a9f
Β·
verified Β·
1 Parent(s): 79fb278

fixed unexpected errors

Browse files
Files changed (1) hide show
  1. app.py +77 -35
app.py CHANGED
@@ -3,8 +3,7 @@ import torch
3
  import torch.nn as nn
4
  import torchvision.models as models
5
  import torchvision.transforms as transforms
6
- from PIL import Image
7
- import matplotlib.pyplot as plt
8
  import pandas as pd
9
  import gradio as gr
10
 
@@ -46,15 +45,28 @@ val_transform = transforms.Compose([
46
  ])
47
 
48
  # ----------------------------
49
- # Predict
 
 
 
 
 
 
 
 
 
 
50
  # ----------------------------
51
  def predict_image(img_path: str):
 
52
  base_name = os.path.basename(img_path)
53
  stem, _ = os.path.splitext(base_name)
54
 
 
55
  img = Image.open(img_path).convert("RGB")
56
  input_tensor = val_transform(img).unsqueeze(0).to(device)
57
 
 
58
  with torch.no_grad():
59
  outputs = model(input_tensor)
60
  probs = torch.nn.functional.softmax(outputs, dim=1)[0]
@@ -63,14 +75,41 @@ def predict_image(img_path: str):
63
  conf = float(top_prob.item()) * 100.0
64
  predicted_breed = breeds[int(top_idx.item())]
65
 
66
- # Annotate image with title
67
- fig, ax = plt.subplots()
68
- ax.imshow(img)
69
- ax.set_title(f"{predicted_breed} ({conf:.2f}%)", fontsize=14, fontweight="bold")
70
- ax.axis("off")
71
- annotated_name = f"{predicted_breed}_{conf:.2f}pct_{stem}.png"
72
- plt.savefig(annotated_name, format="png", bbox_inches="tight", pad_inches=0.1, dpi=150)
73
- plt.close(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # CSV output
76
  df = pd.DataFrame([{
@@ -81,43 +120,46 @@ def predict_image(img_path: str):
81
  csv_name = f"{stem}_prediction.csv"
82
  df.to_csv(csv_name, index=False)
83
 
84
- return predicted_breed, f"{conf:.2f}", annotated_name, annotated_name, csv_name
 
85
 
86
  # ----------------------------
87
- # UI (Modern Layout, Fixed Downloads)
88
  # ----------------------------
89
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
90
- gr.Markdown("<h1 style='text-align:center;'>πŸ„ GoVed AI - Indian Cattle/Buffalo Breed Detection</h1>")
91
- gr.Markdown(
92
- "<p style='text-align:center;'>Upload an image β†’ Detect the breed β†’ View prediction confidence β†’ Download results as CSV or Annotated Image.</p>"
93
- )
 
94
 
95
- with gr.Row():
96
- with gr.Column(scale=1):
97
- img_input = gr.Image(type="filepath", label="Upload Cattle/Buffalo Image", height=300)
98
- submit_btn = gr.Button("πŸ” Detect Breed", elem_id="detect-btn")
99
-
100
- with gr.Column(scale=2):
101
- breed_output = gr.Textbox(label="Predicted Breed", interactive=False)
102
- confidence_output = gr.Label(label="Prediction Confidence (%)")
103
 
104
  with gr.Row():
105
  with gr.Column(scale=1):
106
- img_preview = gr.Image(type="filepath", label="Annotated Image Preview", height=300)
107
- img_download = gr.File(label="⬇️ Download Annotated Image", type="file")
108
-
109
  with gr.Column(scale=1):
110
- csv_download = gr.File(label="⬇️ Download CSV", type="file")
111
-
112
- # Button Click Logic
113
- submit_btn.click(
 
 
 
 
 
 
 
114
  fn=predict_image,
115
  inputs=img_input,
116
- outputs=[breed_output, confidence_output, img_preview, img_download, csv_download],
117
  )
118
 
119
  # ----------------------------
120
- # Run
121
  # ----------------------------
122
  if __name__ == "__main__":
123
  demo.launch()
 
3
  import torch.nn as nn
4
  import torchvision.models as models
5
  import torchvision.transforms as transforms
6
+ from PIL import Image, ImageDraw, ImageFont
 
7
  import pandas as pd
8
  import gradio as gr
9
 
 
45
  ])
46
 
47
  # ----------------------------
48
+ # Helper: safe font load
49
+ # ----------------------------
50
+ def _get_font(size):
51
+ try:
52
+ # common font on many systems; may exist in HF images
53
+ return ImageFont.truetype("DejaVuSans-Bold.ttf", size=size)
54
+ except Exception:
55
+ return ImageFont.load_default()
56
+
57
+ # ----------------------------
58
+ # Predict function (returns: breed, confidence_float, preview_path, download_image_path, csv_path)
59
  # ----------------------------
60
  def predict_image(img_path: str):
61
+ # base filename
62
  base_name = os.path.basename(img_path)
63
  stem, _ = os.path.splitext(base_name)
64
 
65
+ # open and preprocess image
66
  img = Image.open(img_path).convert("RGB")
67
  input_tensor = val_transform(img).unsqueeze(0).to(device)
68
 
69
+ # inference
70
  with torch.no_grad():
71
  outputs = model(input_tensor)
72
  probs = torch.nn.functional.softmax(outputs, dim=1)[0]
 
75
  conf = float(top_prob.item()) * 100.0
76
  predicted_breed = breeds[int(top_idx.item())]
77
 
78
+ # annotate image using PIL (avoid matplotlib dependency)
79
+ annotated = img.copy()
80
+ draw = ImageDraw.Draw(annotated)
81
+
82
+ # choose font size relative to image width
83
+ font_size = max(16, int(annotated.width * 0.04))
84
+ font = _get_font(font_size)
85
+
86
+ text = f"{predicted_breed} ({conf:.2f}%)"
87
+ text_w, text_h = draw.textsize(text, font=font)
88
+
89
+ # rectangle background for text (semi-opaque)
90
+ padding_x = 12
91
+ padding_y = 8
92
+ rect_w = text_w + padding_x * 2
93
+ rect_h = text_h + padding_y * 2
94
+
95
+ # place rectangle centered at top
96
+ rect_x0 = max(0, (annotated.width - rect_w) // 2)
97
+ rect_y0 = 0
98
+ rect_x1 = rect_x0 + rect_w
99
+ rect_y1 = rect_y0 + rect_h
100
+
101
+ # draw rectangle (black)
102
+ draw.rectangle([(rect_x0, rect_y0), (rect_x1, rect_y1)], fill=(0, 0, 0, 200))
103
+ # draw text (white) centered
104
+ text_x = rect_x0 + padding_x
105
+ text_y = rect_y0 + padding_y // 2
106
+ draw.text((text_x, text_y), text, fill=(255, 255, 255), font=font)
107
+
108
+ # Save annotated image (unique filename)
109
+ # sanitize breed name for filename
110
+ safe_breed = "".join(c if c.isalnum() or c in (' ', '_', '-') else '_' for c in predicted_breed).replace(' ', '_')
111
+ annotated_name = f"{safe_breed}_{conf:.2f}pct_{stem}.png"
112
+ annotated.save(annotated_name)
113
 
114
  # CSV output
115
  df = pd.DataFrame([{
 
120
  csv_name = f"{stem}_prediction.csv"
121
  df.to_csv(csv_name, index=False)
122
 
123
+ # Return values: breed string, confidence (float), preview image path, file path for image download, csv path
124
+ return predicted_breed, round(conf, 2), annotated_name, annotated_name, csv_name
125
 
126
  # ----------------------------
127
+ # UI: Gradio Blocks - clean layout
128
  # ----------------------------
129
+ css = """
130
+ /* center buttons and give nicer spacing */
131
+ #detect-btn { width: 100%; }
132
+ .gradio-container { max-width: 1100px; margin: auto; }
133
+ .gr-box { min-height: auto !important; }
134
+ """
135
 
136
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
137
+ gr.Markdown("<h1 style='text-align:center; margin-bottom:0.2rem;'>πŸ„ GoVed AI β€” Indian Cattle/Buffalo Breed Detection</h1>")
138
+ gr.Markdown("<p style='text-align:center; color:var(--muted); margin-top:0.2rem;'>Upload an image β†’ click Detect β†’ see breed, confidence, annotated preview and download results.</p>")
 
 
 
 
 
139
 
140
  with gr.Row():
141
  with gr.Column(scale=1):
142
+ img_input = gr.Image(type="filepath", label="πŸ“€ Upload Cattle/Buffalo Image", height=340)
143
+ detect_btn = gr.Button("πŸ” Detect Breed", elem_id="detect-btn", variant="primary")
 
144
  with gr.Column(scale=1):
145
+ breed_output = gr.Textbox(label="Predicted Breed", interactive=False)
146
+ confidence_output = gr.Number(label="Prediction Confidence (%)", interactive=False, precision=2)
147
+ # preview of annotated image
148
+ annotated_preview = gr.Image(type="filepath", label="Annotated Image Preview", height=340)
149
+ # download buttons (use type='filepath' so gradio serves the saved file)
150
+ with gr.Row():
151
+ image_download = gr.File(label="⬇️ Download Annotated Image", type="filepath")
152
+ csv_download = gr.File(label="⬇️ Download CSV", type="filepath")
153
+
154
+ # Hook up the button
155
+ detect_btn.click(
156
  fn=predict_image,
157
  inputs=img_input,
158
+ outputs=[breed_output, confidence_output, annotated_preview, image_download, csv_download],
159
  )
160
 
161
  # ----------------------------
162
+ # Launch
163
  # ----------------------------
164
  if __name__ == "__main__":
165
  demo.launch()