Spaces:
Sleeping
Sleeping
fixed unexpected errors
Browse files
app.py
CHANGED
|
@@ -3,8 +3,7 @@ import torch
|
|
| 3 |
import torch.nn as nn
|
| 4 |
import torchvision.models as models
|
| 5 |
import torchvision.transforms as transforms
|
| 6 |
-
from PIL import Image
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
import pandas as pd
|
| 9 |
import gradio as gr
|
| 10 |
|
|
@@ -46,15 +45,28 @@ val_transform = transforms.Compose([
|
|
| 46 |
])
|
| 47 |
|
| 48 |
# ----------------------------
|
| 49 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# ----------------------------
|
| 51 |
def predict_image(img_path: str):
|
|
|
|
| 52 |
base_name = os.path.basename(img_path)
|
| 53 |
stem, _ = os.path.splitext(base_name)
|
| 54 |
|
|
|
|
| 55 |
img = Image.open(img_path).convert("RGB")
|
| 56 |
input_tensor = val_transform(img).unsqueeze(0).to(device)
|
| 57 |
|
|
|
|
| 58 |
with torch.no_grad():
|
| 59 |
outputs = model(input_tensor)
|
| 60 |
probs = torch.nn.functional.softmax(outputs, dim=1)[0]
|
|
@@ -63,14 +75,41 @@ def predict_image(img_path: str):
|
|
| 63 |
conf = float(top_prob.item()) * 100.0
|
| 64 |
predicted_breed = breeds[int(top_idx.item())]
|
| 65 |
|
| 66 |
-
#
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# CSV output
|
| 76 |
df = pd.DataFrame([{
|
|
@@ -81,43 +120,46 @@ def predict_image(img_path: str):
|
|
| 81 |
csv_name = f"{stem}_prediction.csv"
|
| 82 |
df.to_csv(csv_name, index=False)
|
| 83 |
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
# ----------------------------
|
| 87 |
-
# UI
|
| 88 |
# ----------------------------
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
submit_btn = gr.Button("π Detect Breed", elem_id="detect-btn")
|
| 99 |
-
|
| 100 |
-
with gr.Column(scale=2):
|
| 101 |
-
breed_output = gr.Textbox(label="Predicted Breed", interactive=False)
|
| 102 |
-
confidence_output = gr.Label(label="Prediction Confidence (%)")
|
| 103 |
|
| 104 |
with gr.Row():
|
| 105 |
with gr.Column(scale=1):
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
with gr.Column(scale=1):
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
fn=predict_image,
|
| 115 |
inputs=img_input,
|
| 116 |
-
outputs=[breed_output, confidence_output,
|
| 117 |
)
|
| 118 |
|
| 119 |
# ----------------------------
|
| 120 |
-
#
|
| 121 |
# ----------------------------
|
| 122 |
if __name__ == "__main__":
|
| 123 |
demo.launch()
|
|
|
|
| 3 |
import torch.nn as nn
|
| 4 |
import torchvision.models as models
|
| 5 |
import torchvision.transforms as transforms
|
| 6 |
+
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
| 7 |
import pandas as pd
|
| 8 |
import gradio as gr
|
| 9 |
|
|
|
|
| 45 |
])
|
| 46 |
|
| 47 |
# ----------------------------
|
| 48 |
+
# Helper: safe font load
|
| 49 |
+
# ----------------------------
|
| 50 |
+
def _get_font(size):
|
| 51 |
+
try:
|
| 52 |
+
# common font on many systems; may exist in HF images
|
| 53 |
+
return ImageFont.truetype("DejaVuSans-Bold.ttf", size=size)
|
| 54 |
+
except Exception:
|
| 55 |
+
return ImageFont.load_default()
|
| 56 |
+
|
| 57 |
+
# ----------------------------
|
| 58 |
+
# Predict function (returns: breed, confidence_float, preview_path, download_image_path, csv_path)
|
| 59 |
# ----------------------------
|
| 60 |
def predict_image(img_path: str):
|
| 61 |
+
# base filename
|
| 62 |
base_name = os.path.basename(img_path)
|
| 63 |
stem, _ = os.path.splitext(base_name)
|
| 64 |
|
| 65 |
+
# open and preprocess image
|
| 66 |
img = Image.open(img_path).convert("RGB")
|
| 67 |
input_tensor = val_transform(img).unsqueeze(0).to(device)
|
| 68 |
|
| 69 |
+
# inference
|
| 70 |
with torch.no_grad():
|
| 71 |
outputs = model(input_tensor)
|
| 72 |
probs = torch.nn.functional.softmax(outputs, dim=1)[0]
|
|
|
|
| 75 |
conf = float(top_prob.item()) * 100.0
|
| 76 |
predicted_breed = breeds[int(top_idx.item())]
|
| 77 |
|
| 78 |
+
# annotate image using PIL (avoid matplotlib dependency)
|
| 79 |
+
annotated = img.copy()
|
| 80 |
+
draw = ImageDraw.Draw(annotated)
|
| 81 |
+
|
| 82 |
+
# choose font size relative to image width
|
| 83 |
+
font_size = max(16, int(annotated.width * 0.04))
|
| 84 |
+
font = _get_font(font_size)
|
| 85 |
+
|
| 86 |
+
text = f"{predicted_breed} ({conf:.2f}%)"
|
| 87 |
+
text_w, text_h = draw.textsize(text, font=font)
|
| 88 |
+
|
| 89 |
+
# rectangle background for text (semi-opaque)
|
| 90 |
+
padding_x = 12
|
| 91 |
+
padding_y = 8
|
| 92 |
+
rect_w = text_w + padding_x * 2
|
| 93 |
+
rect_h = text_h + padding_y * 2
|
| 94 |
+
|
| 95 |
+
# place rectangle centered at top
|
| 96 |
+
rect_x0 = max(0, (annotated.width - rect_w) // 2)
|
| 97 |
+
rect_y0 = 0
|
| 98 |
+
rect_x1 = rect_x0 + rect_w
|
| 99 |
+
rect_y1 = rect_y0 + rect_h
|
| 100 |
+
|
| 101 |
+
# draw rectangle (black)
|
| 102 |
+
draw.rectangle([(rect_x0, rect_y0), (rect_x1, rect_y1)], fill=(0, 0, 0, 200))
|
| 103 |
+
# draw text (white) centered
|
| 104 |
+
text_x = rect_x0 + padding_x
|
| 105 |
+
text_y = rect_y0 + padding_y // 2
|
| 106 |
+
draw.text((text_x, text_y), text, fill=(255, 255, 255), font=font)
|
| 107 |
+
|
| 108 |
+
# Save annotated image (unique filename)
|
| 109 |
+
# sanitize breed name for filename
|
| 110 |
+
safe_breed = "".join(c if c.isalnum() or c in (' ', '_', '-') else '_' for c in predicted_breed).replace(' ', '_')
|
| 111 |
+
annotated_name = f"{safe_breed}_{conf:.2f}pct_{stem}.png"
|
| 112 |
+
annotated.save(annotated_name)
|
| 113 |
|
| 114 |
# CSV output
|
| 115 |
df = pd.DataFrame([{
|
|
|
|
| 120 |
csv_name = f"{stem}_prediction.csv"
|
| 121 |
df.to_csv(csv_name, index=False)
|
| 122 |
|
| 123 |
+
# Return values: breed string, confidence (float), preview image path, file path for image download, csv path
|
| 124 |
+
return predicted_breed, round(conf, 2), annotated_name, annotated_name, csv_name
|
| 125 |
|
| 126 |
# ----------------------------
|
| 127 |
+
# UI: Gradio Blocks - clean layout
|
| 128 |
# ----------------------------
|
| 129 |
+
css = """
|
| 130 |
+
/* center buttons and give nicer spacing */
|
| 131 |
+
#detect-btn { width: 100%; }
|
| 132 |
+
.gradio-container { max-width: 1100px; margin: auto; }
|
| 133 |
+
.gr-box { min-height: auto !important; }
|
| 134 |
+
"""
|
| 135 |
|
| 136 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
| 137 |
+
gr.Markdown("<h1 style='text-align:center; margin-bottom:0.2rem;'>π GoVed AI β Indian Cattle/Buffalo Breed Detection</h1>")
|
| 138 |
+
gr.Markdown("<p style='text-align:center; color:var(--muted); margin-top:0.2rem;'>Upload an image β click Detect β see breed, confidence, annotated preview and download results.</p>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
with gr.Row():
|
| 141 |
with gr.Column(scale=1):
|
| 142 |
+
img_input = gr.Image(type="filepath", label="π€ Upload Cattle/Buffalo Image", height=340)
|
| 143 |
+
detect_btn = gr.Button("π Detect Breed", elem_id="detect-btn", variant="primary")
|
|
|
|
| 144 |
with gr.Column(scale=1):
|
| 145 |
+
breed_output = gr.Textbox(label="Predicted Breed", interactive=False)
|
| 146 |
+
confidence_output = gr.Number(label="Prediction Confidence (%)", interactive=False, precision=2)
|
| 147 |
+
# preview of annotated image
|
| 148 |
+
annotated_preview = gr.Image(type="filepath", label="Annotated Image Preview", height=340)
|
| 149 |
+
# download buttons (use type='filepath' so gradio serves the saved file)
|
| 150 |
+
with gr.Row():
|
| 151 |
+
image_download = gr.File(label="β¬οΈ Download Annotated Image", type="filepath")
|
| 152 |
+
csv_download = gr.File(label="β¬οΈ Download CSV", type="filepath")
|
| 153 |
+
|
| 154 |
+
# Hook up the button
|
| 155 |
+
detect_btn.click(
|
| 156 |
fn=predict_image,
|
| 157 |
inputs=img_input,
|
| 158 |
+
outputs=[breed_output, confidence_output, annotated_preview, image_download, csv_download],
|
| 159 |
)
|
| 160 |
|
| 161 |
# ----------------------------
|
| 162 |
+
# Launch
|
| 163 |
# ----------------------------
|
| 164 |
if __name__ == "__main__":
|
| 165 |
demo.launch()
|