"""Gradio app: run SpeciesNet on an uploaded image, save and upload annotations.

For every uploaded image the app runs the SpeciesNet ensemble, validates the
result, writes the raw JSON and a YOLO-style label file next to the image,
optionally uploads image + labels to a Hugging Face bucket, and returns an
annotated preview together with the JSON output.
"""

import json
import os
import time
import uuid

from IPython.display import display, JSON
import matplotlib.pyplot as plt
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
import numpy as np
import gradio as gr
import cv2
from huggingface_hub import batch_bucket_files

# ------------------------------------------------------
# HF TOKEN (IMPORTANT)
# ------------------------------------------------------
HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Spaces secrets
BUCKET_ID = "codewithRiz/Buck_data_storage"

# ------------------------------------------------------
# LOAD MODEL
# ------------------------------------------------------
print("Default SpeciesNet model:", DEFAULT_MODEL)
print("Supported SpeciesNet models:", SUPPORTED_MODELS)
model = SpeciesNet(DEFAULT_MODEL)


# ------------------------------------------------------
# VALIDATION
# ------------------------------------------------------
def validate_predictions_structure(pred):
    """Check that a single prediction entry has the fields this app relies on.

    Args:
        pred: One element of ``predictions_dict["predictions"]``.

    Returns:
        ``True`` when the entry is well-formed.

    Raises:
        ValueError: If ``filepath``, ``detections`` or ``classifications`` is
            missing, if ``detections`` is not a list, or if
            ``classifications`` is not a mapping with ``classes`` and
            ``scores``.
    """
    required_keys = ["filepath", "detections", "classifications"]
    for key in required_keys:
        if key not in pred:
            raise ValueError(f"Missing key '{key}'")

    if not isinstance(pred["detections"], list):
        raise ValueError("detections must be list")

    cls = pred["classifications"]
    # Guard the type first so a malformed value (e.g. None) surfaces as the
    # documented ValueError rather than a TypeError from the `in` operator.
    if not isinstance(cls, dict) or "classes" not in cls or "scores" not in cls:
        raise ValueError("classification format invalid")

    return True


def validate_model_output(predictions_dict):
    """Validate the whole model output.

    Args:
        predictions_dict: Mapping returned by ``model.predict``.

    Returns:
        ``True`` when every prediction entry is well-formed.

    Raises:
        ValueError: If the ``predictions`` key is missing or any entry fails
            ``validate_predictions_structure``.
    """
    if "predictions" not in predictions_dict:
        raise ValueError("Missing predictions")

    for pred in predictions_dict["predictions"]:
        validate_predictions_structure(pred)

    return True


# ------------------------------------------------------
# SAVE YOLO TXT FORMAT
# ------------------------------------------------------
def save_yolo_annotations(image_path, predictions_dict, txt_path):
    """Write one label line per detection to ``txt_path``.

    Format:
        class_name x_center y_center width height (normalized)

    The bounding boxes are used as-is (already normalized), so the image
    itself is never opened here.

    Args:
        image_path: Path of the source image. Unused; kept so existing
            callers keep working.
        predictions_dict: Model output containing a ``predictions`` list.
        txt_path: Destination path of the label file (overwritten).

    NOTE(review): the class label is the last ``;``-separated field of the top
    class string and is written verbatim; a label containing spaces yields a
    line with more than five whitespace-separated tokens. Confirm downstream
    parsers handle that.
    """
    lines = []

    for pred in predictions_dict.get("predictions", []):
        classes = pred.get("classifications", {}).get("classes", [])
        if not classes:
            continue

        # Every detection in the image is labelled with the top class.
        class_name = classes[0].split(";")[-1]

        for det in pred.get("detections", []):
            # bbox is (left, top, width, height); convert to box centre.
            x, y, bw, bh = det["bbox"]
            x_center = x + bw / 2
            y_center = y + bh / 2
            lines.append(
                f"{class_name} {x_center:.6f} {y_center:.6f} {bw:.6f} {bh:.6f}"
            )

    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


# ------------------------------------------------------
# UPLOAD TO BUCKET (SAFE)
# ------------------------------------------------------
def upload_to_bucket(image_path, txt_path, image_id):
    """Upload the image and its label file to the bucket, best effort.

    The upload is skipped when no token is configured, and any upload error
    is reported on stdout instead of being raised, so inference results are
    still returned to the user.

    Args:
        image_path: Local path of the image file.
        txt_path: Local path of the label file.
        image_id: Identifier used to build the remote file names.
    """
    # `not HF_TOKEN` also covers an empty-string secret, which is unusable.
    if not HF_TOKEN:
        print("⚠ HF_TOKEN missing → skipping upload")
        return

    try:
        batch_bucket_files(
            BUCKET_ID,
            add=[
                (image_path, f"images/{image_id}.jpg"),
                (txt_path, f"labels/{image_id}.txt"),
            ],
            token=HF_TOKEN,
        )
        print("✅ Uploaded to bucket")
    except Exception as e:
        # Deliberately broad: a failed upload must not fail the request.
        print("⚠ Upload failed:", str(e))


# ------------------------------------------------------
# DRAW BOXES
# ------------------------------------------------------
def draw_predictions(image_path, predictions_dict):
    """Draw detection boxes and the top class label onto the image.

    Args:
        image_path: Path of the image to annotate.
        predictions_dict: Model output containing a ``predictions`` list.

    Returns:
        The annotated image as an RGB array.

    Raises:
        ValueError: If the image cannot be read from ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {image_path}")
    h, w = img.shape[:2]

    for pred in predictions_dict.get("predictions", []):
        cls = pred.get("classifications", {})
        classes = cls.get("classes", [])
        scores = cls.get("scores", [])

        if not classes:
            continue

        top_class = classes[0].split(";")[-1]
        # Tolerate a missing score list: fall back to the class name alone.
        label = f"{top_class} {scores[0]:.2f}" if scores else top_class

        for det in pred.get("detections", []):
            # Scale the normalized (left, top, width, height) box to pixels.
            x, y, bw, bh = det["bbox"]
            x1 = int(x * w)
            y1 = int(y * h)
            x2 = int((x + bw) * w)
            y2 = int((y + bh) * h)

            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Keep the label on the canvas when the box touches the top edge.
            label_y = max(y1 - 10, 15)
            cv2.putText(
                img,
                label,
                (x1, label_y),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 255),
                2,
            )

    # OpenCV works in BGR; the UI expects RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# ------------------------------------------------------
# INFERENCE FUNCTION
# ------------------------------------------------------
def inference(image):
    """Run the full pipeline for one uploaded image.

    Saves the image as JPEG, runs the model, validates the output, writes the
    JSON and YOLO label files, uploads image + labels to the bucket, and
    builds the annotated preview.

    Args:
        image: The uploaded image as a PIL image (``None`` when nothing was
            uploaded).

    Returns:
        A tuple ``(annotated_rgb_image, predictions_json_string)``.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If the model output fails validation.
    """
    if image is None:
        raise gr.Error("Please upload an image before running inference.")

    # A timestamp alone collides for requests arriving within the same
    # second; the random suffix keeps local and remote file names unique.
    image_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    image_path = f"{image_id}.jpg"
    txt_path = f"{image_id}.txt"

    # JPEG cannot store alpha or palette modes, so normalize to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(image_path)

    start = time.time()
    predictions_dict = model.predict(
        instances_dict={
            "instances": [
                {
                    "filepath": image_path,
                }
            ]
        }
    )
    end = time.time()
    print(f"Inference time: {end - start:.2f}s")

    # validate
    validate_model_output(predictions_dict)

    # save JSON
    with open(f"{image_id}.json", "w", encoding="utf-8") as f:
        json.dump(predictions_dict, f, indent=4)

    # save YOLO txt
    save_yolo_annotations(image_path, predictions_dict, txt_path)

    # upload to HF bucket
    upload_to_bucket(image_path, txt_path, image_id)

    # visualize
    annotated = draw_predictions(image_path, predictions_dict)

    return annotated, json.dumps(predictions_dict, indent=4)


# ------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Detection Output"),
        gr.JSON(label="Model Output"),
    ],
    title="Wildlife Detector + SpeciesNet",
    description="Upload wildlife image → detect + classify + save to Hugging Face bucket",
)

iface.launch()