SavlonBhai's picture
Update app.py
6b57a1b verified
# app.py
import os
import json
import pickle
import joblib
import warnings
from typing import Tuple, Dict
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tv_models
import torchvision.transforms as T
import numpy as np
from PIL import Image
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")
# Optional dependency: when timm imports cleanly, _build_arch can create timm
# architectures in addition to the torchvision ones.
HAS_TIMM = False
try:
    import timm
except Exception:
    # Any import-time failure (missing package or broken install) just disables timm.
    pass
else:
    HAS_TIMM = True
# ---------------------------
# Defaults & metadata
# ---------------------------
# Default label set; list position is the model output index used by predict().
DEFAULT_CLASSES = [
    "Ayrshire cattle",
    "Brown Swiss cattle",
    "Holstein Friesian cattle",
    "Jaffrabadi",
    "Jersey cattle",
    "Murrah",
    "Red Dane cattle",
    "kankarej",
    "sahiwal",
    "sahiwal cross",
    "sibbi",
]

# Display metadata per class label. Keys must match the class names exactly
# (they are looked up with the predicted label), so spellings such as
# "kankarej" are intentionally kept as-is.
BREED_INFO = {
    "Ayrshire cattle": dict(
        type="Dairy Cow",
        origin="Scotland",
        characteristics="Strong, adaptable, excellent udder conformation and superior grazing ability",
        milk_yield="6000-7000 liters per lactation",
        special_features="Red and white patches, hardy in cold weather, high butterfat content",
        weight="450-550 kg",
        height="125-135 cm",
        temperament="Docile and friendly",
    ),
    "Brown Swiss cattle": dict(
        type="Dual-purpose (Dairy & Beef)",
        origin="Switzerland",
        characteristics="Docile, strong, excellent for cheese production, disease resistant",
        milk_yield="10000-14000 liters per lactation",
        special_features="Light to dark brown color with creamy white muzzle, exceptional longevity",
        weight="600-700 kg",
        height="135-150 cm",
        temperament="Calm and intelligent",
    ),
    "Holstein Friesian cattle": dict(
        type="Dairy Cow",
        origin="Netherlands/Germany",
        characteristics="Highest milk production, excellent feed conversion, docile temperament",
        milk_yield="8000-12000 liters per lactation",
        special_features="Distinctive black and white patches, large frame, heat sensitive",
        weight="580-700 kg",
        height="140-150 cm",
        temperament="Gentle and manageable",
    ),
    "Jaffrabadi": dict(
        type="Indigenous Dairy Buffalo",
        origin="Gujarat, India (Saurashtra region)",
        characteristics="Heaviest Indian buffalo breed, adapted to harsh semi-arid conditions",
        milk_yield="2000-2500 liters per lactation",
        special_features="Black color, dome-shaped forehead, ring-like horns, highest butterfat content",
        weight="400-600 kg",
        height="130-140 cm",
        temperament="Hardy and resilient",
    ),
    "Jersey cattle": dict(
        type="Dairy Cow",
        origin="Jersey, Channel Islands",
        characteristics="Efficient feed conversion, calving ease, heat tolerant, docile",
        milk_yield="4500-6500 liters per lactation",
        special_features="Light tan to fawn color, smallest dairy breed, highest butterfat percentage",
        weight="350-450 kg",
        height="120-125 cm",
        temperament="Alert and intelligent",
    ),
    "Murrah": dict(
        type="Indigenous Dairy Buffalo",
        origin="Haryana and Punjab, India",
        characteristics="Highest milk yielding buffalo breed, docile nature, good mothers",
        milk_yield="2200-3000 liters per lactation",
        special_features="Jet black color, tightly curved horns, compact body structure",
        weight="450-650 kg",
        height="130-135 cm",
        temperament="Docile and calm",
    ),
    "Red Dane cattle": dict(
        type="Dual-purpose (Dairy & Beef)",
        origin="Denmark",
        characteristics="Hardy, disease resistant, excellent meat quality, easy calving",
        milk_yield="8000-10000 liters per lactation",
        special_features="Red to dark mahogany color with white markings, good heat tolerance",
        weight="550-650 kg",
        height="135-145 cm",
        temperament="Gentle and cooperative",
    ),
    "kankarej": dict(
        type="Indigenous Dual-purpose (Dairy & Draught)",
        origin="Gujarat, India (Kankrej territory)",
        characteristics="Active, strong draught animal, drought resistant, disease resistant",
        milk_yield="1500-2000 liters per lactation",
        special_features="Silver to gray to steel black color, lyre-shaped horns, large pendulous ears",
        weight="400-500 kg",
        height="125-135 cm",
        temperament="Active and energetic",
    ),
    "sahiwal": dict(
        type="Indigenous Dairy Cow",
        origin="Punjab, Pakistan/India",
        characteristics="Heat resistant, tick resistant, high disease resistance, docile",
        milk_yield="2500-3200 liters per lactation",
        special_features="Brownish red to grayish red color, loose dewlap, compact build",
        weight="300-400 kg",
        height="115-125 cm",
        temperament="Docile and hardy",
    ),
    "sahiwal cross": dict(
        type="Crossbred Dairy Cow",
        origin="Cross breeding programs (Sahiwal x exotic breeds)",
        characteristics="Hybrid vigor, improved milk yield, better adaptability than pure exotic",
        milk_yield="3000-4200 liters per lactation",
        special_features="Variable color depending on cross, moderate heat tolerance, enhanced productivity",
        weight="350-450 kg",
        height="120-130 cm",
        temperament="Balanced and adaptable",
    ),
    "sibbi": dict(
        type="Indigenous Dual-purpose (Draught & Beef)",
        origin="Sibi, Baluchistan, Pakistan",
        characteristics="Largest Zebu breed, exceptional size, extremely hardy, massive build",
        milk_yield="1500-2200 liters per lactation",
        special_features="Pure white to grey with black neck, tallest cattle breed, exhibited at Sibi Mela",
        weight="500-800 kg",
        height="140-160 cm",
        temperament="Majestic and calm",
    ),
}

# Per-channel RGB mean/std used by the default 224x224 preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# ---------------------------
# Helpers
# ---------------------------
def strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a copy of *state_dict* with one leading ``"module."`` removed from each key.

    Such prefixes appear in checkpoints saved from wrapped models (e.g.
    ``nn.DataParallel``). Keys without the prefix are kept unchanged. If two
    keys collapse to the same name, the later one wins.
    """
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }
def file_to_path(file_obj) -> str:
    """Resolve a Gradio file value to a filesystem path.

    Accepted inputs:
      * a plain path (``str`` or any ``os.PathLike`` such as ``pathlib.Path``);
      * a dict carrying the path under ``"name"``, ``"path"`` or ``"file"``;
      * a tempfile-like object exposing the path as ``.name``.

    Raises:
        ValueError: if the input type is unsupported, or a dict carries no path.
    """
    # Check PathLike before the ``.name`` fallback: pathlib.Path has a ``.name``
    # attribute too, but it is only the final component, not the full path.
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    if isinstance(file_obj, dict):
        path = file_obj.get("name") or file_obj.get("path") or file_obj.get("file")
        if not path:
            raise ValueError("File dict has no 'name', 'path' or 'file' entry")
        return path
    if hasattr(file_obj, "name"):
        return file_obj.name
    raise ValueError("Unsupported file input type")
def make_head(in_dim: int, num_classes: int) -> nn.Module:
    """Build a ``Dropout(0.2) -> Linear(in_dim, num_classes)`` classifier head.

    The layers stay positional (indices 0 and 1) so the Linear sits at index 1,
    the same slot the torchvision EfficientNet classifier uses.
    """
    dropout = nn.Dropout(0.2)
    projection = nn.Linear(in_dim, num_classes)
    return nn.Sequential(dropout, projection)
# ---------------------------
# Classifier
# ---------------------------
class IndianBovineClassifier:
    """Breed classifier wrapper with pluggable backends and a demo fallback.

    ``model_type`` records which backend is active:

    * ``"demo"`` -- untrained placeholder CNN; ``predict`` returns pseudo-random
      probabilities seeded by the image size.
    * ``"pytorch:<arch>"``, ``"pytorch:serialized"``, ``"pytorch:pickle"`` -- a
      torch module run on ``self.device``.
    * ``"sklearn"`` -- any object exposing ``predict_proba`` over flattened
      224x224 RGB pixels scaled to [0, 1].

    Loading is all-or-nothing: if a model file fails to load, the previous
    model, class names and preprocessing are restored.
    """

    # Architecture names that always resolve to torchvision, even with timm installed.
    _TORCHVISION_ONLY = frozenset({"tv_resnet18", "tv_efficientnet_b0"})
    # Architecture assumed when a checkpoint does not record one.
    _DEFAULT_ARCH = "tv_efficientnet_b0"
    # File extension -> loader kind, used when kind="auto".
    _EXT_TO_KIND = {".pth": "pytorch", ".pkl": "pickle", ".joblib": "joblib"}

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.model_type = "demo"
        self.class_names = list(DEFAULT_CLASSES)
        self.num_classes = len(self.class_names)
        self.preprocess = self._default_preprocess()
        self._try_autoload()

    # ------------------------------------------------------------------
    # Architecture construction
    # ------------------------------------------------------------------
    @staticmethod
    def _default_preprocess():
        """Return the 224x224, ImageNet-normalised image transform."""
        return T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])

    def _build_arch(self, arch: str, num_classes: int) -> nn.Module:
        """Instantiate an untrained network for *arch* with *num_classes* outputs.

        Resolution order is: timm (if installed, except for the explicit ``tv_``
        torchvision aliases), then torchvision ResNet-18 or EfficientNet-B0, then
        the small fallback CNN. Also sets ``self.preprocess`` to match the model.
        """
        # Start from the default transform so a previous timm model's input
        # size or normalisation never leaks into a model that doesn't set its own.
        self.preprocess = self._default_preprocess()
        a = (arch or "").strip()
        key = a.lower()
        if a and HAS_TIMM and key not in self._TORCHVISION_ONLY:
            try:
                m = timm.create_model(a, pretrained=False, num_classes=num_classes)
                cfg = getattr(m, "default_cfg", None)
                if cfg:
                    size = cfg.get("input_size", (3, 224, 224))[-1]
                    mean = list(cfg.get("mean", IMAGENET_MEAN))
                    std = list(cfg.get("std", IMAGENET_STD))
                    self.preprocess = T.Compose([
                        T.Resize((size, size)),
                        T.ToTensor(),
                        T.Normalize(mean, std),
                    ])
                return m
            except Exception:
                # Unknown to timm, or timm failed: use the fallbacks below.
                self.preprocess = self._default_preprocess()
        if key in {"resnet18", "tv_resnet18"}:
            m = tv_models.resnet18(weights=None)
            m.fc = nn.Linear(m.fc.in_features, num_classes)
            return m
        if key in {"efficientnet_b0", "tv_efficientnet_b0"}:
            m = tv_models.efficientnet_b0(weights=None)
            in_dim = m.classifier[1].in_features
            m.classifier = make_head(in_dim, num_classes)
            return m
        # fallback
        return self._simple_cnn(num_classes)

    def _simple_cnn(self, nc: int) -> nn.Module:
        """Return a small 3-conv-layer CNN with *nc* outputs (used for demo mode and unknown archs)."""
        class Simple(nn.Module):
            def __init__(self, out_dim):
                super().__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
                    nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(True),
                    nn.AdaptiveAvgPool2d((1, 1)),
                )
                self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(256, out_dim))

            def forward(self, x):
                x = self.features(x)
                x = torch.flatten(x, 1)
                return self.classifier(x)
        return Simple(nc)

    @staticmethod
    def _fit_head(model: nn.Module, nc: int) -> None:
        """Replace the final Linear layer in place if its output size is not *nc*."""
        if hasattr(model, "classifier") and isinstance(model.classifier, nn.Sequential):
            last = model.classifier[-1]
            if isinstance(last, nn.Linear) and last.out_features != nc:
                model.classifier[-1] = nn.Linear(last.in_features, nc)
        elif hasattr(model, "fc") and isinstance(model.fc, nn.Linear) and model.fc.out_features != nc:
            model.fc = nn.Linear(model.fc.in_features, nc)

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def _try_autoload(self):
        """Load the first loadable file from a fixed list of names relative to the CWD.

        Falls back to the untrained demo CNN when none of them loads.
        """
        candidates = (
            ("indian_bovine_breeds.pth", "pytorch"),
            ("indian_bovine_model.pth", "pytorch"),
            ("model.pth", "pytorch"),
            ("indian_bovine_breeds.pkl", "pickle"),
            ("model.pkl", "pickle"),
            ("indian_bovine_breeds.joblib", "joblib"),
            ("model.joblib", "joblib"),
        )
        for path, kind in candidates:
            if not os.path.exists(path):
                continue
            try:
                self._load_from_path(path, kind=kind)
            except Exception as e:
                print(f"Autoload failed for {path}: {e}")
                continue
            print(f"Loaded model: {path}")
            return
        self.model = self._simple_cnn(self.num_classes).to(self.device).eval()
        self.model_type = "demo"

    def _snapshot(self):
        """Capture every attribute a model load may change."""
        return (self.model, self.model_type, list(self.class_names), self.num_classes, self.preprocess)

    def _restore(self, snap) -> None:
        """Undo a partial load using a tuple produced by ``_snapshot``."""
        self.model, self.model_type, names, self.num_classes, self.preprocess = snap
        self.class_names = list(names)

    def _maybe_set_classes_from_meta(self, meta: dict) -> bool:
        """Adopt class names stored in checkpoint metadata, if present.

        Looks first for a list or tuple of 2+ names under ``classes``,
        ``class_names``, ``labels`` or ``breeds``. Failing that, it uses a
        ``class_to_idx`` mapping with 2+ entries. Returns True if names were set.

        Raises:
            ValueError: if ``class_to_idx`` does not use each index 0..N-1 exactly once.
        """
        for k in ("classes", "class_names", "labels", "breeds"):
            names = meta.get(k)
            if isinstance(names, (list, tuple)) and len(names) > 1:
                self.class_names = list(names)
                self.num_classes = len(self.class_names)
                return True
        mapping = meta.get("class_to_idx")
        if isinstance(mapping, dict) and len(mapping) > 1:
            by_index = {idx: name for name, idx in mapping.items()}
            # Duplicate or gapped indices would silently mislabel predictions.
            if len(by_index) != len(mapping) or set(by_index) != set(range(len(by_index))):
                raise ValueError("class_to_idx must map class names to the indices 0..N-1 exactly once each")
            self.class_names = [by_index[i] for i in range(len(by_index))]
            self.num_classes = len(self.class_names)
            return True
        return False

    def _load_pytorch_checkpoint(self, ckpt):
        """Install a model from the object returned by ``torch.load``.

        *ckpt* is either a whole ``nn.Module``, or a dict holding a state dict
        (under ``model_state_dict`` / ``state_dict``, or the dict itself) plus
        optional ``arch``, ``num_classes`` and class-name metadata.

        Raises:
            ValueError: if no state dict is found, or none of its parameter names
                match the chosen architecture.
        """
        if not isinstance(ckpt, dict):
            # Direct serialized torch.nn.Module.
            self.model = ckpt.to(self.device).eval()
            self.model_type = "pytorch:serialized"
            return
        arch = ckpt.get("arch")
        self._maybe_set_classes_from_meta(ckpt)
        nc = ckpt.get("num_classes", self.num_classes)
        state = ckpt.get("model_state_dict", ckpt.get("state_dict"))
        if state is None and all(isinstance(k, str) for k in ckpt.keys()):
            state = ckpt  # raw state dict
        if state is None:
            raise ValueError("No state_dict in checkpoint.")
        state = strip_module_prefix(state)
        if arch:
            candidates = [arch]
        elif HAS_TIMM:
            # Architecture not recorded: timm and torchvision name EfficientNet-B0
            # parameters differently, so try both and keep the better match.
            candidates = [self._DEFAULT_ARCH, "efficientnet_b0"]
        else:
            candidates = [self._DEFAULT_ARCH]
        best = None
        for name in candidates:
            model = self._build_arch(name, nc)
            self._fit_head(model, nc)
            matched = len(set(model.state_dict()) & set(state))
            if best is None or matched > best[0]:
                best = (matched, name, model, self.preprocess)
        matched, name, model, preprocess = best
        if matched == 0:
            # load_state_dict(strict=False) would otherwise "succeed" while
            # loading nothing, leaving random weights behind.
            raise ValueError(
                f"Checkpoint weights do not match architecture '{name}' (no parameter names in common)."
            )
        model.load_state_dict(state, strict=False)
        self.preprocess = preprocess
        self.num_classes = nc
        self.model = model.to(self.device).eval()
        self.model_type = f"pytorch:{name}"

    def _load_generic_object(self, obj):
        """Install an unpickled torch module or an estimator with ``predict_proba``."""
        if hasattr(obj, "eval") and hasattr(obj, "state_dict"):
            self.model = obj.to(self.device).eval()
            self.model_type = "pytorch:pickle"
        elif hasattr(obj, "predict_proba"):
            self.model = obj
            self.model_type = "sklearn"
        else:
            raise ValueError("Unsupported object in file (expect torch module/state_dict or sklearn estimator).")

    def _load_from_path(self, path: str, kind: str = "auto"):
        """Load the model at *path*. On any failure, restore the previous state and re-raise.

        *kind* is ``"pytorch"``, ``"pickle"``, ``"joblib"`` or ``"auto"`` (picked
        from the file extension, defaulting to ``"pytorch"``).
        """
        snapshot = self._snapshot()
        try:
            self._load_from_path_unchecked(path, kind)
        except Exception:
            self._restore(snapshot)
            raise

    def _load_from_path_unchecked(self, path: str, kind: str) -> None:
        """Dispatch to the loader for *kind*. May leave partial state behind on failure."""
        if kind == "auto":
            kind = self._EXT_TO_KIND.get(os.path.splitext(path)[1].lower(), "pytorch")
        # SECURITY: pickle.load, joblib.load and torch.load (unless restricted to
        # weights_only) can execute code embedded in the file. That includes
        # files uploaded through load_user_model, so only load trusted files.
        if kind in ("pytorch", "pickle"):
            snapshot = self._snapshot()
            # Prefer torch.load first for torch checkpoints, even if extension is .pkl
            try:
                ckpt = torch.load(path, map_location=self.device)
                self._load_pytorch_checkpoint(ckpt)
                return
            except Exception as err:
                if kind == "pytorch":
                    raise RuntimeError(f"PyTorch load failed: {err}") from err
                torch_error = err
            # Discard anything the failed torch attempt changed, then try plain pickle.
            self._restore(snapshot)
            try:
                with open(path, "rb") as f:
                    obj = pickle.load(f)
            except pickle.UnpicklingError as pe:
                raise RuntimeError(
                    f"Could not load .pkl file: torch.load failed ({torch_error}) "
                    f"and it is not a plain pickle ({pe})."
                ) from pe
            self._load_generic_object(obj)
            return
        if kind == "joblib":
            obj = joblib.load(path)
            self._load_generic_object(obj)
            return
        raise ValueError(f"Unknown model kind: {kind}")

    # public API for UI
    def load_user_model(self, file_obj) -> str:
        """Load a user-supplied model file and return a status line for the UI."""
        path = file_to_path(file_obj)
        self._load_from_path(path, kind="auto")
        return f"✅ Loaded model: {os.path.basename(path)} | Type: {self.model_type} | Classes: {self.num_classes}"

    def load_classes_json(self, file_obj) -> str:
        """Replace the class names with a JSON list of 2 or more names.

        The list is not checked against the model's output size. ``predict``
        labels any output without a name as ``class_<idx>``.
        """
        path = file_to_path(file_obj)
        with open(path, "r", encoding="utf-8") as f:
            names = json.load(f)
        if not isinstance(names, list) or len(names) < 2:
            raise ValueError("classes.json must be a list with 2 or more class names.")
        self.class_names = list(names)
        self.num_classes = len(names)
        return f"✅ Loaded {len(names)} class names from {os.path.basename(path)}"

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def preprocess_img(self, image: "Image.Image"):
        """Convert a PIL image into the input format of the active backend.

        Torch and demo backends get a ``(1, C, H, W)`` tensor on ``self.device``.
        Other backends get a ``(1, 224*224*3)`` float array scaled to [0, 1].
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        if self.model_type.startswith("pytorch") or self.model_type == "demo":
            return self.preprocess(image).unsqueeze(0).to(self.device)
        arr = np.array(image.resize((224, 224))).astype(np.float32) / 255.0
        return arr.flatten().reshape(1, -1)

    def _label(self, idx: int) -> str:
        """Return the class name for output *idx*, or ``class_<idx>`` if none is known."""
        if 0 <= idx < len(self.class_names):
            return self.class_names[idx]
        return f"class_{idx}"

    def predict(self, image: "Image.Image") -> Tuple[Dict[str, float], str]:
        """Classify *image*.

        Returns ``({"Top k: <label>": prob, ...}, top_label)`` for the top 3
        classes, or ``({"Error": message}, "Unknown")`` when no model is loaded
        or the model produced no probabilities.
        """
        if self.model is None:
            return {"Error": "Model not loaded"}, "Unknown"
        if self.model_type == "demo":
            # Deterministic per image size, using a local generator so the global
            # NumPy RNG is untouched. hash() is randomised per process, so it is avoided.
            width, height = image.size
            seed = (int(width) * 1_000_003 + int(height)) % (2**32)
            probs = np.random.default_rng(seed).dirichlet(np.ones(self.num_classes) * 3.0)
        elif self.model_type.startswith("pytorch"):
            x = self.preprocess_img(image)
            with torch.no_grad():
                logits = self.model(x)
                probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        elif self.model_type == "sklearn":
            x = self.preprocess_img(image)
            probs = np.asarray(self.model.predict_proba(x)[0])
        else:
            probs = np.random.default_rng(42).dirichlet(np.ones(self.num_classes) * 2.0)
        if len(probs) == 0:
            return {"Error": "Model produced no class probabilities"}, "Unknown"
        top_idx = [int(i) for i in np.argsort(probs)[::-1][:3]]
        results = {f"Top {rank}: {self._label(idx)}": float(probs[idx]) for rank, idx in enumerate(top_idx, start=1)}
        return results, self._label(top_idx[0])
# ---------------------------
# UI callbacks
# ---------------------------
# Single shared classifier instance used by every UI callback below.
# Constructing it runs _try_autoload(), which probes a fixed list of model
# filenames (relative paths, i.e. the current working directory) and falls
# back to demo mode when none of them loads.
classifier = IndianBovineClassifier()
def classify_image(image: Image.Image):
    """Gradio callback: classify *image* and build the three result panes.

    Returns ``(results_markdown, breed_description_markdown, stats_table_markdown)``.
    A missing image, a missing model and inference failures are reported in the
    panes instead of being raised, matching the upload callbacks.
    """
    if image is None:
        return (
            "Please upload an image of cattle or buffalo",
            "Upload an image to see detailed breed information",
            "| Attribute | Value |\n|-----------|-------|\n| Status | Awaiting image upload |",
        )
    try:
        preds, top_breed = classifier.predict(image)
    except Exception as e:
        preds, top_breed = {"Error": f"Classification failed: {e}"}, "Unknown"
    if "Error" in preds:
        msg = preds["Error"]
        return (
            f"❌ {msg}",
            "Error occurred during classification",
            f"| Attribute | Value |\n|-----------|-------|\n| Status | Error: {msg} |",
        )
    indicator = "DEMO - " if classifier.model_type == "demo" else f"{classifier.model_type} - "
    md = f"{indicator}Classification Results:\n\n"
    md += "".join(f"- {k}: {v:.2%}\n" for k, v in preds.items())
    if classifier.model_type == "demo":
        md += "\nDemo mode: Upload a .pth/.pkl/.joblib model for real predictions."
    info = BREED_INFO.get(top_breed)
    if info is None:
        # Labels from classes.json or a checkpoint may differ in case from BREED_INFO keys.
        wanted = str(top_breed).strip().casefold()
        info = next((v for k, v in BREED_INFO.items() if k.casefold() == wanted), None)
    if info is not None:
        # Bullet list: Markdown joins single-newline-separated lines into one paragraph.
        desc = f"""
## 🐄 {top_breed}
- **Type:** {info['type']}
- **Origin:** {info['origin']}
- **Characteristics:** {info['characteristics']}
- **Milk Yield:** {info['milk_yield']}
- **Special Features:** {info['special_features']}
- **Weight:** {info['weight']}
- **Height:** {info['height']}
- **Temperament:** {info['temperament']}
"""
        table = f"""| Attribute | Value |
|-----------|-------|
| Type | {info['type']} |
| Origin | {info['origin']} |
| Weight | {info['weight']} |
| Height | {info['height']} |
| Milk Yield | {info['milk_yield']} |
| Temperament | {info['temperament']} |"""
    else:
        desc = "Detailed information not available for this breed."
        table = "| Attribute | Value |\n|-----------|-------|\n| Status | Information not available |"
    return md, desc, table
def upload_and_load_model(file_obj):
    """Gradio callback for the "Load model" button.

    Returns a status string for the sidebar. Load errors are reported in
    that string rather than raised.
    """
    if file_obj:
        try:
            status = classifier.load_user_model(file_obj)
        except Exception as err:
            status = f"❌ Failed to load model: {err}"
        return status
    return "Please select a .pth, .pkl or .joblib file to load."
def upload_classes(file_obj):
    """Gradio callback for the "Load classes.json" button.

    Returns a status string for the sidebar. Parse and validation errors are
    reported in that string rather than raised.
    """
    if file_obj:
        try:
            status = classifier.load_classes_json(file_obj)
        except Exception as err:
            status = f"❌ Failed to load classes.json: {err}"
        return status
    return "Please select a classes.json file."
# ---------------------------
# Minimal, responsive CSS
# ---------------------------
# Stylesheet passed to gr.Blocks(css=...). It styles the "header" and "footer"
# <div>s that create_interface emits, and the media query shrinks the title on
# screens up to 768px wide.
CUSTOM_CSS = """
.gradio-container { min-height: 100vh; }
.header { text-align:center; padding: 1rem; }
.header .title { font-size: 2em; font-weight: 700; }
.footer { text-align:center; opacity:.75; padding:.75rem; }
@media (max-width: 768px) {
.title { font-size: 1.6em !important; }
}
"""
# ---------------------------
# Interface
# ---------------------------
def create_interface():
    """Build and return the (not yet launched) Gradio Blocks UI.

    The sidebar loads model files and class-name mappings into the shared
    ``classifier``. The main area classifies an uploaded image into three tabs.
    """

    def model_status() -> str:
        # Autoload may already have loaded a real model at startup, so derive the
        # initial status from the classifier instead of assuming demo mode.
        if classifier.model_type == "demo":
            return "Model status: DEMO (no trained weights loaded)"
        return f"Model status: {classifier.model_type} | Classes: {classifier.num_classes}"

    def footer_html() -> str:
        # Re-evaluated after every model load so the footer never shows a stale type.
        return f"""<div class="footer">Model type: {classifier.model_type} • PyTorch {torch.__version__}</div>"""

    with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(), fill_width=True, title="Indian Bovine Classifier") as app:
        gr.HTML(f"""
<div class="header">
<div class="title">Indian Bovine Breeds Classifier</div>
<div>PyTorch runtime • {len(DEFAULT_CLASSES)} default classes • Device: {classifier.device}</div>
</div>
""")
        # Collapsible sidebar
        with gr.Sidebar():
            gr.Markdown("### Model loader")
            model_file = gr.File(label="Upload .pth / .pkl / .joblib", file_types=[".pth", ".pkl", ".joblib"], file_count="single")
            load_btn = gr.Button("Load model", variant="primary")
            load_status = gr.Markdown(model_status())
            gr.Markdown("### Classes mapping")
            classes_file = gr.File(label="Upload classes.json", file_types=[".json"], file_count="single")
            load_classes_btn = gr.Button("Load classes.json")
            classes_status = gr.Markdown("No external classes.json loaded.")
        # Main canvas
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=320, variant="panel"):
                gr.Markdown("### Upload image")
                image_input = gr.Image(type="pil", label="Cattle/Buffalo image")
                classify_btn = gr.Button("Classify", variant="secondary")
            with gr.Column(scale=2, min_width=360, variant="panel"):
                with gr.Tab("Results"):
                    prediction_output = gr.Markdown(value="Upload an image to see classification.")
                with gr.Tab("Breed info"):
                    breed_info_output = gr.Markdown(value="Breed info will appear here.")
                with gr.Tab("Stats"):
                    breed_stats_table = gr.Markdown(value="| Attribute | Value |\n|-----------|-------|\n| Status | Awaiting classification... |")
        footer = gr.Markdown(footer_html())
        # Wiring
        load_btn.click(upload_and_load_model, inputs=[model_file], outputs=[load_status]).then(
            footer_html, inputs=None, outputs=[footer]
        )
        load_classes_btn.click(upload_classes, inputs=[classes_file], outputs=[classes_status])
        result_panes = [prediction_output, breed_info_output, breed_stats_table]
        classify_btn.click(classify_image, inputs=[image_input], outputs=result_panes)
        image_input.change(classify_image, inputs=[image_input], outputs=result_panes)
    return app
if __name__ == "__main__":
    app = create_interface()
    # Launch controls via env vars (optional). GRADIO_SERVER_NAME and
    # GRADIO_SERVER_PORT are honoured here. Passing a hard-coded port to
    # launch() would override Gradio's own environment handling.
    truthy = {"1", "true", "yes"}
    share_flag = os.getenv("GRADIO_SHARE", "0").lower() in truthy
    ssr_flag = os.getenv("GRADIO_SSR_MODE", "true").lower() in truthy
    server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    app.launch(server_name=server_name, server_port=server_port, share=share_flag, ssr_mode=ssr_flag)