aditya-sah commited on
Commit
365f668
·
verified ·
1 Parent(s): af448bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -28
app.py CHANGED
@@ -1,9 +1,10 @@
1
- import os
2
  import torch
3
  import torch.nn as nn
4
  import torchvision.models as models
5
  import torchvision.transforms as transforms
6
- from PIL import Image, ImageDraw, ImageFont
 
7
  import pandas as pd
8
  import gradio as gr
9
 
@@ -48,6 +49,7 @@ val_transform = transforms.Compose([
48
  # Predict
49
  # ----------------------------
50
  def predict_image(img_path: str):
 
51
  base_name = os.path.basename(img_path)
52
  stem, _ = os.path.splitext(base_name)
53
 
@@ -62,26 +64,14 @@ def predict_image(img_path: str):
62
  conf = float(top_prob.item()) * 100.0
63
  predicted_breed = breeds[int(top_idx.item())]
64
 
65
- # Create a new image with extra space on top for text
66
- extra_height = 60
67
- new_img = Image.new("RGB", (img.width, img.height + extra_height), "white")
68
- new_img.paste(img, (0, extra_height))
69
-
70
- # Draw text at top-center
71
- draw = ImageDraw.Draw(new_img)
72
- try:
73
- font = ImageFont.truetype("arial.ttf", 28) # Bold font if available
74
- except:
75
- font = ImageFont.load_default()
76
-
77
- text = f"{predicted_breed} ({conf:.2f}%)"
78
- text_width, text_height = draw.textsize(text, font=font)
79
- x = (new_img.width - text_width) // 2
80
- y = (extra_height - text_height) // 2
81
- draw.text((x, y), text, fill="black", font=font)
82
-
83
  annotated_name = f"{predicted_breed}_{conf:.2f}pct_{stem}.png"
84
- new_img.save(annotated_name)
 
85
 
86
  # CSV output
87
  df = pd.DataFrame([{
@@ -92,7 +82,12 @@ def predict_image(img_path: str):
92
  csv_name = f"{stem}_prediction.csv"
93
  df.to_csv(csv_name, index=False)
94
 
95
- return predicted_breed, f"{conf:.2f}%", annotated_name, csv_name
 
 
 
 
 
96
 
97
 
98
  # ----------------------------
@@ -100,15 +95,15 @@ def predict_image(img_path: str):
100
  # ----------------------------
101
  demo = gr.Interface(
102
  fn=predict_image,
103
- inputs=gr.Image(type="filepath", label="📤 Upload Cattle/Buffalo Image"),
104
  outputs=[
105
  gr.Textbox(label="Predicted Breed"),
106
- gr.Textbox(label="Confidence (%)"),
107
- gr.File(label="⬇️ Download Annotated Image", type="filepath"),
108
- gr.File(label="⬇️ Download Prediction CSV", type="filepath"),
109
  ],
110
- title="🐄 GoVed AI – Indian Cattle/Buffalo Breed Detection",
111
- description="Upload an image → Get predicted breed, confidence, annotated image, and CSV download."
112
  )
113
 
114
  if __name__ == "__main__":
 
1
+ import os, io
2
  import torch
3
  import torch.nn as nn
4
  import torchvision.models as models
5
  import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
  import pandas as pd
9
  import gradio as gr
10
 
 
49
  # Predict
50
  # ----------------------------
51
  def predict_image(img_path: str):
52
+ # Keep original filename for outputs
53
  base_name = os.path.basename(img_path)
54
  stem, _ = os.path.splitext(base_name)
55
 
 
64
  conf = float(top_prob.item()) * 100.0
65
  predicted_breed = breeds[int(top_idx.item())]
66
 
67
+ # Annotate image (title overlay)
68
+ fig, ax = plt.subplots()
69
+ ax.imshow(img)
70
+ ax.set_title(f"{predicted_breed} ({conf:.2f}%)")
71
+ ax.axis("off")
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  annotated_name = f"{predicted_breed}_{conf:.2f}pct_{stem}.png"
73
+ plt.savefig(annotated_name, format="png", bbox_inches="tight", pad_inches=0.1, dpi=150)
74
+ plt.close(fig)
75
 
76
  # CSV output
77
  df = pd.DataFrame([{
 
82
  csv_name = f"{stem}_prediction.csv"
83
  df.to_csv(csv_name, index=False)
84
 
85
+ # Return:
86
+ # 1) predicted breed (text)
87
+ # 2) confidence (%) (text)
88
+ # 3) file (CSV)
89
+ # 4) file (annotated image with breed+confidence in filename)
90
+ return predicted_breed, f"{conf:.2f}%", csv_name, annotated_name
91
 
92
 
93
  # ----------------------------
 
95
  # ----------------------------
96
  demo = gr.Interface(
97
  fn=predict_image,
98
+ inputs=gr.Image(type="filepath", label="Upload Cattle/Buffalo Image"),
99
  outputs=[
100
  gr.Textbox(label="Predicted Breed"),
101
+ gr.Textbox(label="Prediction Confidence (%)"),
102
+ gr.File(label="Download CSV"),
103
+ gr.File(label="Download Annotated Image")
104
  ],
105
+ title="Indian Cattle/Buffalo Breed Detection",
106
+ description="Upload an image → get predicted breed, confidence score, CSV, and an annotated image file named with the predicted breed and confidence."
107
  )
108
 
109
  if __name__ == "__main__":