aditya-sah commited on
Commit
d6cf8a7
Β·
verified Β·
1 Parent(s): a9665b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -79
app.py CHANGED
@@ -1,15 +1,21 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torchvision.models as models
4
  import torchvision.transforms as transforms
5
- from PIL import Image, ImageDraw, ImageFont
 
 
6
  import gradio as gr
7
- import csv
8
- import os
9
 
10
- # -------------------------------
11
- # Load Breeds
12
- # -------------------------------
 
 
 
 
 
13
  breeds = [
14
  "Alambadi", "Amritmahal", "Ayrshire", "Banni", "Bargur", "Bhadawari", "Brown_Swiss",
15
  "Dangi", "Deoni", "Gir", "Guernsey", "Hallikar", "Hariana", "Holstein_Friesian",
@@ -18,101 +24,96 @@ breeds = [
18
  "Nili_Ravi", "Nimari", "Ongole", "Pulikulam", "Rathi", "Red_Dane", "Red_Sindhi",
19
  "Sahiwal", "Surti", "Tharparkar", "Toda", "Umblachery", "Vechur"
20
  ]
21
-
22
- device = "cuda" if torch.cuda.is_available() else "cpu"
23
-
24
- # -------------------------------
25
- # Load Model
26
- # -------------------------------
27
  num_classes = len(breeds)
 
 
 
 
28
  model = models.efficientnet_b0(weights=None)
29
  model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
30
- model.load_state_dict(torch.load("bovine_model.pth", map_location=device))
31
- model.to(device)
32
- model.eval()
33
-
34
- # -------------------------------
35
- # Image Transform
36
- # -------------------------------
37
- transform = transforms.Compose([
38
  transforms.Resize((224, 224)),
39
  transforms.ToTensor(),
40
- transforms.Normalize([0.485, 0.456, 0.406],
41
  [0.229, 0.224, 0.225])
42
  ])
43
 
44
- # -------------------------------
45
- # Prediction Function
46
- # -------------------------------
47
- def predict(img):
48
- img = img.convert("RGB")
49
- input_tensor = transform(img).unsqueeze(0).to(device)
 
 
 
50
 
51
  with torch.no_grad():
52
  outputs = model(input_tensor)
53
  probs = torch.nn.functional.softmax(outputs, dim=1)[0]
54
 
55
  top_prob, top_idx = torch.max(probs, dim=0)
56
- predicted_breed = breeds[top_idx]
57
- confidence = top_prob.item() * 100
58
-
59
- # -------------------------------
60
- # Annotate image with prediction
61
- # -------------------------------
62
- annotated_img = img.copy()
63
- draw = ImageDraw.Draw(annotated_img)
64
- font = ImageFont.load_default()
65
-
66
- text = f"{predicted_breed} ({confidence:.2f}%)"
67
- text_w, text_h = draw.textsize(text, font=font)
68
-
69
- # Draw rectangle background
70
- draw.rectangle([(0, 0), (text_w + 10, text_h + 10)], fill="black")
71
- draw.text((5, 5), text, fill="white", font=font)
72
-
73
- # Save annotated image
74
- img_filename = f"{predicted_breed}_{int(confidence)}.png"
75
- annotated_img.save(img_filename)
76
-
77
- # Save CSV file
78
- csv_filename = "prediction.csv"
79
- with open(csv_filename, "w", newline="") as f:
80
- writer = csv.writer(f)
81
- writer.writerow(["Breed", "Confidence"])
82
- writer.writerow([predicted_breed, f"{confidence:.2f}%"])
83
-
84
- return (
85
- f"{predicted_breed} ({confidence:.2f}%)",
86
- annotated_img,
87
- img_filename,
88
- csv_filename
89
- )
90
-
91
- # -------------------------------
92
- # Gradio Interface
93
- # -------------------------------
94
- with gr.Blocks(theme="default") as demo:
95
- gr.Markdown("<h1 style='text-align:center;'>GoVed AI πŸ„</h1>")
96
  gr.Markdown(
97
- "Upload a cattle image to detect the breed, view prediction confidence, "
98
- "and download results as **annotated image + CSV**."
99
  )
100
 
101
  with gr.Row():
102
- with gr.Column():
103
- img_input = gr.Image(type="pil", label="Upload Cattle Image")
104
- submit_btn = gr.Button("πŸ” Detect Breed")
105
- with gr.Column():
106
- breed_output = gr.Textbox(label="Predicted Breed")
107
- img_output = gr.Image(type="pil", label="Image with Breed Info")
108
- img_download = gr.File(label="Download Annotated Image")
109
- csv_download = gr.File(label="Download CSV File")
 
 
 
 
110
 
111
  submit_btn.click(
112
- fn=predict,
113
  inputs=img_input,
114
- outputs=[breed_output, img_output, img_download, csv_download]
115
  )
116
 
 
 
 
117
  if __name__ == "__main__":
118
  demo.launch()
 
1
+ import os
2
  import torch
3
  import torch.nn as nn
4
  import torchvision.models as models
5
  import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import pandas as pd
9
  import gradio as gr
 
 
10
 
11
+ # ----------------------------
12
+ # Device
13
+ # ----------------------------
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ # ----------------------------
17
+ # Labels
18
+ # ----------------------------
19
  breeds = [
20
  "Alambadi", "Amritmahal", "Ayrshire", "Banni", "Bargur", "Bhadawari", "Brown_Swiss",
21
  "Dangi", "Deoni", "Gir", "Guernsey", "Hallikar", "Hariana", "Holstein_Friesian",
 
24
  "Nili_Ravi", "Nimari", "Ongole", "Pulikulam", "Rathi", "Red_Dane", "Red_Sindhi",
25
  "Sahiwal", "Surti", "Tharparkar", "Toda", "Umblachery", "Vechur"
26
  ]
 
 
 
 
 
 
27
  num_classes = len(breeds)
28
+
29
+ # ----------------------------
30
+ # Model
31
+ # ----------------------------
32
  model = models.efficientnet_b0(weights=None)
33
  model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
34
+ state = torch.load("bovine_model.pth", map_location=device)
35
+ model.load_state_dict(state)
36
+ model.to(device).eval()
37
+
38
+ # ----------------------------
39
+ # Preprocessing
40
+ # ----------------------------
41
+ val_transform = transforms.Compose([
42
  transforms.Resize((224, 224)),
43
  transforms.ToTensor(),
44
+ transforms.Normalize([0.485, 0.456, 0.406],
45
  [0.229, 0.224, 0.225])
46
  ])
47
 
48
+ # ----------------------------
49
+ # Predict
50
+ # ----------------------------
51
+ def predict_image(img_path: str):
52
+ base_name = os.path.basename(img_path)
53
+ stem, _ = os.path.splitext(base_name)
54
+
55
+ img = Image.open(img_path).convert("RGB")
56
+ input_tensor = val_transform(img).unsqueeze(0).to(device)
57
 
58
  with torch.no_grad():
59
  outputs = model(input_tensor)
60
  probs = torch.nn.functional.softmax(outputs, dim=1)[0]
61
 
62
  top_prob, top_idx = torch.max(probs, dim=0)
63
+ conf = float(top_prob.item()) * 100.0
64
+ predicted_breed = breeds[int(top_idx.item())]
65
+
66
+ # Annotate image with title
67
+ fig, ax = plt.subplots()
68
+ ax.imshow(img)
69
+ ax.set_title(f"{predicted_breed} ({conf:.2f}%)", fontsize=14, fontweight="bold")
70
+ ax.axis("off")
71
+ annotated_name = f"{predicted_breed}_{conf:.2f}pct_{stem}.png"
72
+ plt.savefig(annotated_name, format="png", bbox_inches="tight", pad_inches=0.1, dpi=150)
73
+ plt.close(fig)
74
+
75
+ # CSV output
76
+ df = pd.DataFrame([{
77
+ "breed": predicted_breed,
78
+ "confidence_percent": f"{conf:.2f}%",
79
+ "filename": base_name
80
+ }])
81
+ csv_name = f"{stem}_prediction.csv"
82
+ df.to_csv(csv_name, index=False)
83
+
84
+ return predicted_breed, f"{conf:.2f}", annotated_name, csv_name
85
+
86
+ # ----------------------------
87
+ # UI (Modern Layout)
88
+ # ----------------------------
89
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
90
+ gr.Markdown("<h1 style='text-align:center;'>πŸ„ GoVed AI - Indian Cattle/Buffalo Breed Detection</h1>")
 
 
 
 
 
 
 
 
 
 
 
 
91
  gr.Markdown(
92
+ "<p style='text-align:center;'>Upload an image β†’ Detect the breed β†’ View prediction confidence β†’ Download results as CSV or Annotated Image.</p>"
 
93
  )
94
 
95
  with gr.Row():
96
+ with gr.Column(scale=1):
97
+ img_input = gr.Image(type="filepath", label="Upload Cattle/Buffalo Image", height=300)
98
+ submit_btn = gr.Button("πŸ” Detect Breed", elem_id="detect-btn")
99
+
100
+ with gr.Column(scale=2):
101
+ breed_output = gr.Textbox(label="Predicted Breed", interactive=False)
102
+ confidence_output = gr.Label(label="Prediction Confidence (%)")
103
+ img_output = gr.Image(type="filepath", label="Annotated Image Preview", height=300)
104
+
105
+ with gr.Row():
106
+ img_download = gr.File(label="⬇️ Download Annotated Image")
107
+ csv_download = gr.File(label="⬇️ Download CSV")
108
 
109
  submit_btn.click(
110
+ fn=predict_image,
111
  inputs=img_input,
112
+ outputs=[breed_output, confidence_output, img_output, csv_download],
113
  )
114
 
115
+ # ----------------------------
116
+ # Run
117
+ # ----------------------------
118
  if __name__ == "__main__":
119
  demo.launch()