File size: 3,625 Bytes
365f668
3e6d9dd
 
 
 
365f668
 
d6cf8a7
3e6d9dd
 
d6cf8a7
 
 
 
 
 
 
 
3e6d9dd
 
 
 
 
 
 
 
 
d6cf8a7
 
 
 
3e6d9dd
 
d6cf8a7
 
 
 
 
 
 
 
3e6d9dd
 
d6cf8a7
3e6d9dd
 
 
1fe5a9f
fa4119e
d6cf8a7
 
365f668
d6cf8a7
 
 
 
 
3e6d9dd
 
 
 
 
 
d6cf8a7
 
 
365f668
 
 
 
 
459a033
365f668
 
d6cf8a7
459a033
d6cf8a7
 
 
 
 
 
 
 
365f668
 
 
 
 
 
459a033
d6cf8a7
 
fa4119e
d6cf8a7
459a033
 
365f668
459a033
 
365f668
 
 
459a033
365f668
 
459a033
3e6d9dd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os, io
import tempfile

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
import gradio as gr

# ----------------------------
# Device
# ----------------------------
# Run on the CUDA GPU when PyTorch reports one is available; fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# ----------------------------
# Labels
# ----------------------------
# Class names indexed by classifier output position: logit i maps to breeds[i].
# NOTE(review): this order must match the class order used when the checkpoint
# was trained (it appears alphabetical, as ImageFolder would produce) — confirm.
# Names contain no whitespace, so a whitespace split yields the exact list.
breeds = """
    Alambadi Amritmahal Ayrshire Banni Bargur Bhadawari Brown_Swiss
    Dangi Deoni Gir Guernsey Hallikar Hariana Holstein_Friesian
    Jaffrabadi Jersey Kangayam Kankrej Kasargod Kenkatha Kherigarh
    Khillari Krishna_Valley Malnad_gidda Mehsana Murrah Nagori Nagpuri
    Nili_Ravi Nimari Ongole Pulikulam Rathi Red_Dane Red_Sindhi
    Sahiwal Surti Tharparkar Toda Umblachery Vechur
""".split()
# Width of the classification head; derived from the label table so they can't drift.
num_classes = len(breeds)

# ----------------------------
# Model
# ----------------------------
# EfficientNet-B0 built without pretrained weights. The final Linear layer of
# its classifier is replaced with a num_classes-way head so that the shapes
# match the fine-tuned checkpoint, whose weights are then restored.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (weights_only=True is the safer option on torch versions that support it).
model = models.efficientnet_b0(weights=None)
_head_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(_head_in_features, num_classes)

state = torch.load("bovine_model.pth", map_location=device)
model.load_state_dict(state)

# Module.to() moves parameters in place and returns the same module object.
model = model.to(device)
# Inference mode for layers such as dropout/batch-norm.
model.eval()

# ----------------------------
# Preprocessing
# ----------------------------
# Inference-time pipeline: resize to a fixed 224x224 (aspect ratio is not
# preserved), convert the PIL image to a tensor, then normalize each channel
# with the standard ImageNet mean/std values.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]
)

# ----------------------------
# Predict
# ----------------------------
def predict_image(img_path: str):
    """Classify the breed in one image and write CSV and annotated-PNG reports.

    Args:
        img_path: Filesystem path of the uploaded image (the Gradio input
            component uses ``type="filepath"``). ``None``/empty when the
            user submits without an image.

    Returns:
        A 4-tuple matching the four Gradio output components:
        ``(predicted_breed, confidence_text, csv_path, annotated_png_path)``.
        ``confidence_text`` is formatted like ``"97.31%"``. Both file paths
        live in a fresh per-call temporary directory, and their basenames
        carry the original image's stem (and, for the PNG, the predicted
        breed and confidence).

    Raises:
        gr.Error: If no image was supplied.
    """
    if not img_path:
        # An empty image component passes None; show a readable UI message
        # instead of crashing with a TypeError in os.path.basename.
        raise gr.Error("Please upload an image first.")

    # Keep original filename for outputs
    base_name = os.path.basename(img_path)
    stem, _ = os.path.splitext(base_name)

    # Write reports into a unique directory per request. Concurrent or
    # repeated uploads with the same filename therefore cannot overwrite each
    # other, and reports no longer accumulate in the working directory.
    out_dir = tempfile.mkdtemp(prefix="bovine_pred_")

    # The context manager closes the underlying file handle. convert() returns
    # an independent, fully loaded copy that remains usable afterwards.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    input_tensor = val_transform(img).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]

    top_prob, top_idx = torch.max(probs, dim=0)
    conf = float(top_prob.item()) * 100.0
    predicted_breed = breeds[int(top_idx.item())]
    conf_text = f"{conf:.2f}%"

    # Annotate image (title overlay). Save via the Figure object rather than
    # pyplot's implicit "current figure", which is shared global state across
    # Gradio's worker threads. Always close the figure, even if saving fails.
    annotated_name = os.path.join(
        out_dir, f"{predicted_breed}_{conf:.2f}pct_{stem}.png"
    )
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.set_title(f"{predicted_breed} ({conf_text})")
        ax.axis("off")
        fig.savefig(
            annotated_name,
            format="png",
            bbox_inches="tight",
            pad_inches=0.1,
            dpi=150,
        )
    finally:
        plt.close(fig)

    # CSV output
    df = pd.DataFrame([{
        "breed": predicted_breed,
        "confidence_percent": conf_text,
        "filename": base_name,
    }])
    csv_name = os.path.join(out_dir, f"{stem}_prediction.csv")
    df.to_csv(csv_name, index=False)

    # Return:
    # 1) predicted breed (text)
    # 2) confidence (%) (text)
    # 3) file (CSV)
    # 4) file (annotated image with breed+confidence in filename)
    return predicted_breed, conf_text, csv_name, annotated_name


# ----------------------------
# UI
# ----------------------------
# Output components, in the same order as predict_image's 4-tuple return value.
_output_components = [
    gr.Textbox(label="Predicted Breed"),
    gr.Textbox(label="Prediction Confidence (%)"),
    gr.File(label="Download CSV"),
    gr.File(label="Download Annotated Image"),
]

# The input delivers a filesystem path, which is what predict_image expects.
_image_input = gr.Image(type="filepath", label="Upload Cattle/Buffalo Image")

demo = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=_output_components,
    title="Indian Cattle/Buffalo Breed Detection",
    description="Upload an image → get predicted breed, confidence score, CSV, and an annotated image file named with the predicted breed and confidence.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()