Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import json | |
| import pickle | |
| import joblib | |
| import warnings | |
| from typing import Tuple, Dict | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as tv_models | |
| import torchvision.transforms as T | |
| import numpy as np | |
| from PIL import Image | |
# Silence third-party warning noise (torch / torchvision / timm).
# NOTE(review): this filter is process-wide and may also hide security-related
# warnings emitted by torch.load; consider narrowing it with category=/module=.
warnings.filterwarnings("ignore")

# timm is optional: when importable, arbitrary timm architectures can be built
# by name; otherwise only the torchvision aliases and a simple CNN are available.
try:
    import timm
except Exception:
    HAS_TIMM = False
else:
    HAS_TIMM = True
# ---------------------------
# Defaults & metadata
# ---------------------------
# Fallback label order used until a checkpoint or classes.json supplies names;
# entry i is the label reported for model output index i. These exact
# spellings (e.g. "kankarej", "sibbi") are also the lookup keys into
# BREED_INFO and presumably match the labels the model was trained on --
# confirm before "correcting" them.
DEFAULT_CLASSES = [
    "Ayrshire cattle","Brown Swiss cattle","Holstein Friesian cattle",
    "Jaffrabadi","Jersey cattle","Murrah","Red Dane cattle",
    "kankarej","sahiwal","sahiwal cross","sibbi"
]
# Descriptive metadata rendered by classify_image(). A key must equal the
# predicted class name exactly; names without an entry get a generic
# "information not available" message. Every entry carries the same eight
# string fields: type, origin, characteristics, milk_yield, special_features,
# weight, height, temperament.
# NOTE(review): "Jaffrabadi" is described as the heaviest Indian buffalo breed,
# yet its weight range below is lower than Murrah's -- verify these figures.
BREED_INFO = {
    "Ayrshire cattle":{"type":"Dairy Cow","origin":"Scotland",
        "characteristics":"Strong, adaptable, excellent udder conformation and superior grazing ability",
        "milk_yield":"6000-7000 liters per lactation",
        "special_features":"Red and white patches, hardy in cold weather, high butterfat content",
        "weight":"450-550 kg","height":"125-135 cm","temperament":"Docile and friendly"},
    "Brown Swiss cattle":{"type":"Dual-purpose (Dairy & Beef)","origin":"Switzerland",
        "characteristics":"Docile, strong, excellent for cheese production, disease resistant",
        "milk_yield":"10000-14000 liters per lactation",
        "special_features":"Light to dark brown color with creamy white muzzle, exceptional longevity",
        "weight":"600-700 kg","height":"135-150 cm","temperament":"Calm and intelligent"},
    "Holstein Friesian cattle":{"type":"Dairy Cow","origin":"Netherlands/Germany",
        "characteristics":"Highest milk production, excellent feed conversion, docile temperament",
        "milk_yield":"8000-12000 liters per lactation",
        "special_features":"Distinctive black and white patches, large frame, heat sensitive",
        "weight":"580-700 kg","height":"140-150 cm","temperament":"Gentle and manageable"},
    "Jaffrabadi":{"type":"Indigenous Dairy Buffalo","origin":"Gujarat, India (Saurashtra region)",
        "characteristics":"Heaviest Indian buffalo breed, adapted to harsh semi-arid conditions",
        "milk_yield":"2000-2500 liters per lactation",
        "special_features":"Black color, dome-shaped forehead, ring-like horns, highest butterfat content",
        "weight":"400-600 kg","height":"130-140 cm","temperament":"Hardy and resilient"},
    "Jersey cattle":{"type":"Dairy Cow","origin":"Jersey, Channel Islands",
        "characteristics":"Efficient feed conversion, calving ease, heat tolerant, docile",
        "milk_yield":"4500-6500 liters per lactation",
        "special_features":"Light tan to fawn color, smallest dairy breed, highest butterfat percentage",
        "weight":"350-450 kg","height":"120-125 cm","temperament":"Alert and intelligent"},
    "Murrah":{"type":"Indigenous Dairy Buffalo","origin":"Haryana and Punjab, India",
        "characteristics":"Highest milk yielding buffalo breed, docile nature, good mothers",
        "milk_yield":"2200-3000 liters per lactation",
        "special_features":"Jet black color, tightly curved horns, compact body structure",
        "weight":"450-650 kg","height":"130-135 cm","temperament":"Docile and calm"},
    "Red Dane cattle":{"type":"Dual-purpose (Dairy & Beef)","origin":"Denmark",
        "characteristics":"Hardy, disease resistant, excellent meat quality, easy calving",
        "milk_yield":"8000-10000 liters per lactation",
        "special_features":"Red to dark mahogany color with white markings, good heat tolerance",
        "weight":"550-650 kg","height":"135-145 cm","temperament":"Gentle and cooperative"},
    "kankarej":{"type":"Indigenous Dual-purpose (Dairy & Draught)","origin":"Gujarat, India (Kankrej territory)",
        "characteristics":"Active, strong draught animal, drought resistant, disease resistant",
        "milk_yield":"1500-2000 liters per lactation",
        "special_features":"Silver to gray to steel black color, lyre-shaped horns, large pendulous ears",
        "weight":"400-500 kg","height":"125-135 cm","temperament":"Active and energetic"},
    "sahiwal":{"type":"Indigenous Dairy Cow","origin":"Punjab, Pakistan/India",
        "characteristics":"Heat resistant, tick resistant, high disease resistance, docile",
        "milk_yield":"2500-3200 liters per lactation",
        "special_features":"Brownish red to grayish red color, loose dewlap, compact build",
        "weight":"300-400 kg","height":"115-125 cm","temperament":"Docile and hardy"},
    "sahiwal cross":{"type":"Crossbred Dairy Cow","origin":"Cross breeding programs (Sahiwal x exotic breeds)",
        "characteristics":"Hybrid vigor, improved milk yield, better adaptability than pure exotic",
        "milk_yield":"3000-4200 liters per lactation",
        "special_features":"Variable color depending on cross, moderate heat tolerance, enhanced productivity",
        "weight":"350-450 kg","height":"120-130 cm","temperament":"Balanced and adaptable"},
    "sibbi":{"type":"Indigenous Dual-purpose (Draught & Beef)","origin":"Sibi, Baluchistan, Pakistan",
        "characteristics":"Largest Zebu breed, exceptional size, extremely hardy, massive build",
        "milk_yield":"1500-2200 liters per lactation",
        "special_features":"Pure white to grey with black neck, tallest cattle breed, exhibited at Sibi Mela",
        "weight":"500-800 kg","height":"140-160 cm","temperament":"Majestic and calm"}
}
# Standard ImageNet per-channel (RGB) normalisation statistics. They are the
# default for image preprocessing unless a timm model config supplies its own.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# ---------------------------
# Helpers
# ---------------------------
def strip_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a copy of *state_dict* with one leading ``"module."`` removed from keys.

    Wrappers such as ``nn.DataParallel`` save parameters as ``module.<name>``;
    dropping the prefix lets the weights load into an unwrapped model. Keys
    without the prefix are copied unchanged and insertion order is preserved.
    """
    prefix = "module."
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }
def file_to_path(file_obj) -> str:
    """Resolve a Gradio file value (or a plain path) to a filesystem path string.

    Accepted inputs, checked in this order:
      * ``str`` -- returned unchanged;
      * ``os.PathLike`` (e.g. ``pathlib.Path``) -- converted with ``os.fspath``;
      * any object with a ``name`` attribute (e.g. tempfile wrappers) -- its ``name``;
      * ``dict`` payloads -- first truthy value among ``name``, ``path``, ``file``.

    Raises:
        ValueError: unsupported input type, or a dict that carries no path.
    """
    if isinstance(file_obj, str):
        return file_obj
    # Must run before the ``name`` check: ``pathlib.Path.name`` is only the
    # basename, so the directory would be silently dropped.
    if isinstance(file_obj, os.PathLike):
        return os.fspath(file_obj)
    if hasattr(file_obj, "name"):
        return file_obj.name
    if isinstance(file_obj, dict):
        path = file_obj.get("name") or file_obj.get("path") or file_obj.get("file")
        if not path:
            raise ValueError("File payload does not contain a path")
        return path
    raise ValueError("Unsupported file input type")
def make_head(in_dim: int, num_classes: int) -> nn.Module:
    """Build a classification head: ``Dropout(0.2)`` then ``Linear(in_dim, num_classes)``.

    The two-layer layout keeps the linear layer at index 1, so a model whose
    ``classifier`` is replaced with this head exposes ``classifier.1.*`` weights.
    """
    layers = [
        nn.Dropout(0.2),
        nn.Linear(in_dim, num_classes),
    ]
    return nn.Sequential(*layers)
# ---------------------------
# Classifier
# ---------------------------
class IndianBovineClassifier:
    """Breed classifier wrapping a PyTorch or scikit-learn model.

    Construction tries to auto-load the first model file found among a fixed
    list of names in the current working directory (see ``_try_autoload``).
    If none loads, the instance stays in "demo" mode and returns
    pseudo-random scores.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``.
        model: the loaded ``nn.Module`` or sklearn estimator.
        model_type: ``"demo"``, ``"sklearn"`` or ``"pytorch:<detail>"``.
        class_names: label reported for each model output index.
        num_classes: class count used to size new heads and demo output.
        preprocess: torchvision transform applied before torch inference.

    Security: model files are read with ``torch.load``, ``pickle`` and
    ``joblib``. Pickle-based formats can execute arbitrary code while loading,
    and the UI lets users upload such files, so only load trusted files.
    """

    # Names that are always built with torchvision, bypassing timm. Their
    # checkpoints use torchvision parameter names, which timm models of the
    # same family may not share (e.g. EfficientNet).
    _TORCHVISION_ONLY = frozenset({"tv_resnet18", "tv_efficientnet_b0"})
    # Bare names timm may claim. If timm's build shares no parameter names with
    # a checkpoint, loading retries with the "tv_" (torchvision) variant.
    _TV_RETRYABLE = frozenset({"resnet18", "efficientnet_b0"})

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.model_type = "demo"
        self.class_names = list(DEFAULT_CLASSES)
        self.num_classes = len(self.class_names)
        self.preprocess = self._make_preprocess()
        self._try_autoload()

    # ---- architecture construction ------------------------------------
    @staticmethod
    def _make_preprocess(size: int = 224, mean=None, std=None):
        """Return a square resize -> tensor -> normalize transform.

        ``mean``/``std`` default to the ImageNet statistics.
        """
        mean = IMAGENET_MEAN if mean is None else mean
        std = IMAGENET_STD if std is None else std
        return T.Compose([
            T.Resize((size, size)),
            T.ToTensor(),
            T.Normalize(list(mean), list(std)),
        ])

    def _build_arch(self, arch: str, num_classes: int) -> nn.Module:
        """Instantiate an untrained network for *arch* with *num_classes* outputs.

        Resolution order:
          1. timm, if it is installed and *arch* is not a torchvision-only alias;
          2. the torchvision ResNet-18 / EfficientNet-B0 aliases;
          3. the small built-in CNN.

        Also resets ``self.preprocess``, so a transform configured for a
        previously built timm model is never reused for this one.
        """
        a = (arch or "").strip()
        key = a.lower()
        self.preprocess = self._make_preprocess()
        if a and HAS_TIMM and key not in self._TORCHVISION_ONLY:
            try:
                m = timm.create_model(a, pretrained=False, num_classes=num_classes)
                cfg = getattr(m, "default_cfg", None)
                if cfg:
                    self.preprocess = self._make_preprocess(
                        cfg.get("input_size", (3, 224, 224))[-1],
                        cfg.get("mean"),
                        cfg.get("std"),
                    )
                return m
            except Exception:
                # Unknown to timm (or unusable config): try the aliases below.
                self.preprocess = self._make_preprocess()
        if key in {"resnet18", "tv_resnet18"}:
            m = tv_models.resnet18(weights=None)
            m.fc = nn.Linear(m.fc.in_features, num_classes)
            return m
        if key in {"efficientnet_b0", "tv_efficientnet_b0"}:
            m = tv_models.efficientnet_b0(weights=None)
            m.classifier = make_head(m.classifier[1].in_features, num_classes)
            return m
        # Unknown or empty arch: small randomly initialised CNN.
        return self._simple_cnn(num_classes)

    def _simple_cnn(self, nc: int) -> nn.Module:
        """Return a 3-conv-layer CNN with *nc* outputs (demo / fallback model)."""
        class Simple(nn.Module):
            def __init__(self, out_dim):
                super().__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(True), nn.MaxPool2d(2),
                    nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(True),
                    nn.AdaptiveAvgPool2d((1, 1)),
                )
                self.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(256, out_dim))

            def forward(self, x):
                x = self.features(x)
                x = torch.flatten(x, 1)
                return self.classifier(x)

        return Simple(nc)

    # ---- loading ------------------------------------------------------
    def _try_autoload(self):
        """Load the first candidate file that exists and loads cleanly, else demo.

        Candidate names are relative, so they resolve against the current
        working directory. Failures are printed and the next candidate tried.
        """
        candidates = [
            ("indian_bovine_breeds.pth", "pytorch"),
            ("indian_bovine_model.pth", "pytorch"),
            ("model.pth", "pytorch"),
            ("indian_bovine_breeds.pkl", "pickle"),
            ("model.pkl", "pickle"),
            ("indian_bovine_breeds.joblib", "joblib"),
            ("model.joblib", "joblib"),
        ]
        for path, kind in candidates:
            if not os.path.exists(path):
                continue
            try:
                self._load_from_path(path, kind=kind)
            except Exception as e:
                print(f"Autoload failed for {path}: {e}")
                continue
            print(f"Loaded model: {path}")
            return
        self.model = self._simple_cnn(self.num_classes).to(self.device).eval()
        self.model_type = "demo"

    def _maybe_set_classes_from_meta(self, meta: dict) -> bool:
        """Adopt class names from checkpoint metadata; return True if any were found.

        Looks first for a list/tuple with more than one entry under ``classes``,
        ``class_names``, ``labels`` or ``breeds``. Otherwise it uses a
        ``class_to_idx`` mapping of name -> output index (the shape of
        torchvision's ``ImageFolder.class_to_idx``).
        """
        for k in ("classes", "class_names", "labels", "breeds"):
            names = meta.get(k)
            if isinstance(names, (list, tuple)) and len(names) > 1:
                self.class_names = list(names)
                self.num_classes = len(self.class_names)
                return True
        mapping = meta.get("class_to_idx")
        if isinstance(mapping, dict) and mapping:
            by_idx = {int(i): name for name, i in mapping.items()}
            # Index-aligned list. A gap in the indices gets a placeholder
            # instead of raising KeyError, so each name stays on its output.
            self.class_names = [by_idx.get(i, f"class_{i}") for i in range(max(by_idx) + 1)]
            self.num_classes = len(self.class_names)
            return True
        return False

    def _load_pytorch_checkpoint(self, ckpt):
        """Install a model from an object returned by ``torch.load``.

        *ckpt* may be either:
          * a whole ``nn.Module``; or
          * a dict: a raw state_dict, or a checkpoint holding
            ``model_state_dict`` / ``state_dict`` / ``model`` plus optional
            ``arch``, ``num_classes`` and class-name metadata.

        A missing ``arch`` means torchvision EfficientNet-B0.

        Raises:
            ValueError: no state_dict is found, or none of its parameter names
                match the chosen architecture.
        """
        if not isinstance(ckpt, dict):
            # Whole serialized nn.Module (torch.save(model)).
            self.model = ckpt.to(self.device).eval()
            self.model_type = "pytorch:serialized"
            return

        arch = ckpt.get("arch")
        self._maybe_set_classes_from_meta(ckpt)
        nc = ckpt.get("num_classes", self.num_classes)
        state = ckpt.get("model_state_dict", ckpt.get("state_dict"))
        if state is None and isinstance(ckpt.get("model"), dict):
            state = ckpt["model"]
        if state is None and all(isinstance(k, str) for k in ckpt.keys()):
            state = ckpt  # raw state dict
        if state is None:
            raise ValueError("No state_dict in checkpoint.")
        state = strip_module_prefix(state)

        # An arch-less checkpoint always meant torchvision EfficientNet-B0 (the
        # label previously reported). Building plain "efficientnet_b0" would
        # pick timm's variant when timm is installed.
        arch = arch or "tv_efficientnet_b0"
        model = self._build_arch(arch, nc)
        if set(model.state_dict()).isdisjoint(state) and arch.lower() in self._TV_RETRYABLE:
            arch = "tv_" + arch.lower()
            model = self._build_arch(arch, nc)
        if set(model.state_dict()).isdisjoint(state):
            # load_state_dict(strict=False) would "succeed" while loading
            # nothing, leaving a randomly initialised network in service.
            raise ValueError(f"Checkpoint parameters do not match architecture '{arch}'.")

        # Ensure the classifier head matches the requested class count.
        if hasattr(model, "classifier") and isinstance(model.classifier, nn.Sequential):
            last = model.classifier[-1]
            if isinstance(last, nn.Linear) and last.out_features != nc:
                model.classifier[-1] = nn.Linear(last.in_features, nc)
        elif hasattr(model, "fc") and isinstance(model.fc, nn.Linear) and model.fc.out_features != nc:
            model.fc = nn.Linear(model.fc.in_features, nc)

        result = model.load_state_dict(state, strict=False)
        if result.missing_keys:
            print(f"Warning: {len(result.missing_keys)} model parameters were not found in the checkpoint")
        self.num_classes = nc
        self.model = model.to(self.device).eval()
        self.model_type = f"pytorch:{arch}"

    def _load_generic_object(self, obj):
        """Install an unpickled object: a torch module or an sklearn-style estimator.

        Raises:
            ValueError: the object is neither.
        """
        if hasattr(obj, "eval") and hasattr(obj, "state_dict"):
            self.model = obj.to(self.device).eval()
            self.model_type = "pytorch:pickle"
        elif hasattr(obj, "predict_proba"):
            self.model = obj
            self.model_type = "sklearn"
            # predict_proba columns follow ``obj.classes_``. Adopt string
            # labels so column i is reported under its own name.
            labels = getattr(obj, "classes_", None)
            if labels is not None and len(labels) > 1 and all(isinstance(c, str) for c in labels):
                self.class_names = list(labels)
                self.num_classes = len(self.class_names)
        else:
            raise ValueError("Unsupported object in file (expect torch module/state_dict or sklearn estimator).")

    def _load_from_path(self, path: str, kind: str = "auto"):
        """Load a model file, leaving the previous state intact if loading fails.

        *kind* is ``"pytorch"``, ``"pickle"``, ``"joblib"`` or ``"auto"``
        (chosen from the file extension).
        """
        snapshot = (self.model, self.model_type, list(self.class_names),
                    self.num_classes, self.preprocess)
        try:
            self._load_from_path_impl(path, kind)
        except BaseException:
            # Loading mutates class names / preprocess before the weights are
            # known to fit. Roll back so the current model and its labels stay
            # consistent (otherwise predict() could index past class_names).
            (self.model, self.model_type, self.class_names,
             self.num_classes, self.preprocess) = snapshot
            raise

    def _load_from_path_impl(self, path: str, kind: str):
        """Dispatch on *kind* and load; may partially mutate state on failure."""
        if kind == "auto":
            ext = os.path.splitext(path)[1].lower()
            kind = {".pkl": "pickle", ".joblib": "joblib"}.get(ext, "pytorch")
        if kind in ("pytorch", "pickle"):
            # Prefer torch.load first for torch checkpoints, even if the
            # extension is .pkl.
            try:
                ckpt = torch.load(path, map_location=self.device)
                self._load_pytorch_checkpoint(ckpt)
                return
            except Exception as torch_err:
                if kind == "pytorch":
                    raise RuntimeError(f"PyTorch load failed: {torch_err}") from torch_err
                torch_error = torch_err  # `torch_err` is unbound after the except block
            # sklearn-style pickle fallback.
            # SECURITY: pickle.load runs arbitrary code embedded in the file;
            # only load trusted files.
            try:
                with open(path, "rb") as f:
                    obj = pickle.load(f)
            except pickle.UnpicklingError as pe:
                raise RuntimeError(
                    f"Could not load {os.path.basename(path)}: torch.load failed "
                    f"({torch_error}) and it is not a plain pickle ({pe})."
                ) from pe
            self._load_generic_object(obj)
            return
        if kind == "joblib":
            # SECURITY: joblib.load is pickle-based; same caveat as above.
            obj = joblib.load(path)
            self._load_generic_object(obj)
            return
        raise ValueError(f"Unknown model kind: {kind}")

    # ---- public API for UI --------------------------------------------
    def load_user_model(self, file_obj) -> str:
        """Load an uploaded model file and return a status line for the UI."""
        path = file_to_path(file_obj)
        self._load_from_path(path, kind="auto")
        return f"β Loaded model: {os.path.basename(path)} | Type: {self.model_type} | Classes: {self.num_classes}"

    def load_classes_json(self, file_obj) -> str:
        """Replace class names from a JSON list file and return a status line.

        Raises:
            ValueError: the JSON is not a list of at least two entries.
        """
        path = file_to_path(file_obj)
        with open(path, "r", encoding="utf-8") as f:
            names = json.load(f)
        if not isinstance(names, list) or len(names) < 2:
            raise ValueError("classes.json must be a list with 2 or more class names.")
        self.class_names = list(names)
        self.num_classes = len(names)
        return f"β Loaded {len(names)} class names from {os.path.basename(path)}"

    # ---- inference ----------------------------------------------------
    def preprocess_img(self, image: Image.Image):
        """Convert *image* to model input.

        Torch/demo models get a ``(1, C, H, W)`` tensor on ``self.device``.
        sklearn models get a ``(1, 224*224*3)`` float array scaled to [0, 1].
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        if self.model_type.startswith("pytorch") or self.model_type == "demo":
            return self.preprocess(image).unsqueeze(0).to(self.device)
        arr = np.array(image.resize((224, 224))).astype(np.float32) / 255.0
        return arr.flatten().reshape(1, -1)

    def _label(self, idx: int):
        """Class name for output *idx*; ``class_<idx>`` if names run out."""
        return self.class_names[idx] if 0 <= idx < len(self.class_names) else f"class_{idx}"

    def predict(self, image: Image.Image) -> Tuple[Dict[str, float], str]:
        """Classify *image*.

        Returns:
            ``({"Top i: <label>": probability, ...}, top_label)`` with at most
            three entries, highest probability first. If no model is loaded,
            returns ``({"Error": "Model not loaded"}, "Unknown")``.
        """
        if self.model is None:
            return {"Error": "Model not loaded"}, "Unknown"
        if self.model_type == "demo":
            # Repeatable per image size. The local Generator is seeded with the
            # (width, height) ints, so results are stable across processes
            # (unlike hash(str)), and NumPy's global RNG is left untouched.
            rng = np.random.default_rng([int(v) for v in image.size])
            probs = rng.dirichlet(np.ones(self.num_classes) * 3.0)
        elif self.model_type.startswith("pytorch"):
            x = self.preprocess_img(image)
            with torch.no_grad():
                logits = self.model(x)
            probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        elif self.model_type == "sklearn":
            probs = np.asarray(self.model.predict_proba(self.preprocess_img(image)))[0]
        else:
            probs = np.random.default_rng(42).dirichlet(np.ones(self.num_classes) * 2.0)
        top_idx = [int(i) for i in np.argsort(probs)[::-1][:3]]
        results = {
            f"Top {rank}: {self._label(i)}": float(probs[i])
            for rank, i in enumerate(top_idx, start=1)
        }
        return results, self._label(top_idx[0])
# ---------------------------
# UI callbacks
# ---------------------------
# Module-level instance shared by every Gradio callback below. Constructing it
# calls IndianBovineClassifier._try_autoload(), so importing this module probes
# the current working directory for model files and may load one (falling back
# to demo mode otherwise).
classifier = IndianBovineClassifier()
def classify_image(image: Image.Image):
    """Gradio callback: classify *image* and build the three result panes.

    Returns:
        ``(results_md, breed_description_md, stats_table_md)`` Markdown strings.

    Prediction failures are rendered into the panes instead of being raised,
    matching the model / classes.json upload callbacks.
    """
    # NOTE(review): "β" and "π" in the strings below look like emoji that
    # were mis-decoded (UTF-8 bytes read as ISO-8859-7); confirm the intended
    # characters against the original file before changing them.
    if image is None:
        return (
            "Please upload an image of cattle or buffalo",
            "Upload an image to see detailed breed information",
            "| Attribute | Value |\n|-----------|-------|\n| Status | Awaiting image upload |",
        )
    try:
        preds, top_breed = classifier.predict(image)
    except Exception as e:  # UI boundary: report instead of failing the event
        preds, top_breed = {"Error": f"Prediction failed: {e}"}, "Unknown"
    if "Error" in preds:
        msg = preds["Error"]
        # Keep the message inside a single Markdown table cell.
        cell = str(msg).replace("|", "\\|").replace("\n", " ")
        return (
            f"β {msg}",
            "Error occurred during classification",
            f"| Attribute | Value |\n|-----------|-------|\n| Status | Error: {cell} |",
        )
    indicator = "DEMO - " if classifier.model_type == "demo" else f"{classifier.model_type} - "
    md = f"{indicator}Classification Results:\n\n"
    md += "".join(f"- {k}: {v:.2%}\n" for k, v in preds.items())
    if classifier.model_type == "demo":
        md += "\nDemo mode: Upload a .pth/.pkl/.joblib model for real predictions."
    if top_breed in BREED_INFO:
        info = BREED_INFO[top_breed]
        # Blank lines between fields: single newlines are soft breaks in
        # Markdown and would merge every field into one run-on paragraph.
        desc = f"""
## π {top_breed}

Type: {info['type']}

Origin: {info['origin']}

Characteristics: {info['characteristics']}

Milk Yield: {info['milk_yield']}

Special Features: {info['special_features']}

Weight: {info['weight']}

Height: {info['height']}

Temperament: {info['temperament']}
"""
        table = f"""| Attribute | Value |
|-----------|-------|
| Type | {info['type']} |
| Origin | {info['origin']} |
| Weight | {info['weight']} |
| Height | {info['height']} |
| Milk Yield | {info['milk_yield']} |
| Temperament | {info['temperament']} |"""
    else:
        desc = "Detailed information not available for this breed."
        table = "| Attribute | Value |\n|-----------|-------|\n| Status | Information not available |"
    return md, desc, table
def upload_and_load_model(file_obj):
    """Gradio callback: load an uploaded model file into the shared classifier.

    Always returns a status string for the UI; load errors are reported in
    the message rather than raised.
    """
    if file_obj:
        try:
            return classifier.load_user_model(file_obj)
        except Exception as err:
            return f"β Failed to load model: {err}"
    return "Please select a .pth, .pkl or .joblib file to load."
def upload_classes(file_obj):
    """Gradio callback: replace class names from an uploaded classes.json.

    Always returns a status string for the UI; parse/validation errors are
    reported in the message rather than raised.
    """
    if file_obj:
        try:
            return classifier.load_classes_json(file_obj)
        except Exception as err:
            return f"β Failed to load classes.json: {err}"
    return "Please select a classes.json file."
# ---------------------------
# Minimal, responsive CSS
# ---------------------------
# Passed to gr.Blocks(css=...). `.header` / `.title` style the HTML banner,
# `.footer` styles the model-type footer, and the media query shrinks the
# title on screens 768px wide or narrower.
CUSTOM_CSS = """
.gradio-container { min-height: 100vh; }
.header { text-align:center; padding: 1rem; }
.header .title { font-size: 2em; font-weight: 700; }
.footer { text-align:center; opacity:.75; padding:.75rem; }
@media (max-width: 768px) {
.title { font-size: 1.6em !important; }
}
"""
# ---------------------------
# Interface
# ---------------------------
def create_interface():
    """Build and return the Gradio Blocks app (not launched).

    Layout:
      * an HTML header;
      * a sidebar with the model and classes.json loaders;
      * an image-input column;
      * Results / Breed info / Stats tabs.

    Classification runs on the button click and whenever the image changes.
    """
    # NOTE(review): "β’" below looks like a mis-decoded "•" (UTF-8 read as
    # ISO-8859-7); confirm against the original file before changing it.

    def _model_status() -> str:
        # Reflect what _try_autoload() actually loaded at startup instead of
        # always claiming DEMO mode.
        if classifier.model_type == "demo":
            return "Model status: DEMO (no trained weights loaded)"
        return f"Model status: {classifier.model_type} | Classes: {classifier.num_classes}"

    def _footer_html() -> str:
        # Re-evaluated after each model upload so the footer never goes stale.
        return f"""<div class="footer">Model type: {classifier.model_type} β’ PyTorch {torch.__version__}</div>"""

    with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(), fill_width=True, title="Indian Bovine Classifier") as app:
        gr.HTML(f"""
<div class="header">
<div class="title">Indian Bovine Breeds Classifier</div>
<div>PyTorch runtime β’ {len(DEFAULT_CLASSES)} default classes β’ Device: {classifier.device}</div>
</div>
""")
        # Collapsible sidebar
        with gr.Sidebar():
            gr.Markdown("### Model loader")
            model_file = gr.File(label="Upload .pth / .pkl / .joblib", file_types=[".pth",".pkl",".joblib"], file_count="single")
            load_btn = gr.Button("Load model", variant="primary")
            load_status = gr.Markdown(_model_status())
            load_event = load_btn.click(upload_and_load_model, inputs=[model_file], outputs=[load_status])
            gr.Markdown("### Classes mapping")
            classes_file = gr.File(label="Upload classes.json", file_types=[".json"], file_count="single")
            load_classes_btn = gr.Button("Load classes.json")
            classes_status = gr.Markdown("No external classes.json loaded.")
            load_classes_btn.click(upload_classes, inputs=[classes_file], outputs=[classes_status])
        # Main canvas
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=320, variant="panel"):
                gr.Markdown("### Upload image")
                image_input = gr.Image(type="pil", label="Cattle/Buffalo image")
                classify_btn = gr.Button("Classify", variant="secondary")
            with gr.Column(scale=2, min_width=360, variant="panel"):
                with gr.Tab("Results"):
                    prediction_output = gr.Markdown(value="Upload an image to see classification.")
                with gr.Tab("Breed info"):
                    breed_info_output = gr.Markdown(value="Breed info will appear here.")
                with gr.Tab("Stats"):
                    breed_stats_table = gr.Markdown(value="| Attribute | Value |\n|-----------|-------|\n| Status | Awaiting classification... |")
        footer = gr.Markdown(_footer_html())
        # Refresh the footer once a model upload attempt finishes.
        load_event.then(_footer_html, inputs=None, outputs=[footer])
        # Wiring
        classify_btn.click(classify_image, inputs=[image_input], outputs=[prediction_output, breed_info_output, breed_stats_table])
        image_input.change(classify_image, inputs=[image_input], outputs=[prediction_output, breed_info_output, breed_stats_table])
    return app
if __name__ == "__main__":
    app = create_interface()
    # Launch controls via env vars (optional).
    share_flag = os.getenv("GRADIO_SHARE", "0").lower() in {"1", "true", "yes"}
    ssr_flag = os.getenv("GRADIO_SSR_MODE", "true").lower() in {"1", "true", "yes"}
    # Gradio documents GRADIO_SERVER_NAME / GRADIO_SERVER_PORT for launch(),
    # but explicit arguments override them, so honour the env vars here.
    # Defaults: all interfaces, port 7860 (the previous hard-coded values).
    server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    app.launch(server_name=server_name, server_port=server_port, share=share_flag, ssr_mode=ssr_flag)