Spaces:
Sleeping
Sleeping
import json
import os
import time
import uuid

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import batch_bucket_files
from IPython.display import display, JSON
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
# ------------------------------------------------------
# HF TOKEN (IMPORTANT)
# ------------------------------------------------------
# Set HF_TOKEN in the Space's secrets. An empty value is normalized to None so
# "not configured" has a single representation for the upload code.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
# Target storage bucket; overridable via the BUCKET_ID environment variable.
BUCKET_ID = os.environ.get("BUCKET_ID", "codewithRiz/Buck_data_storage")
# ------------------------------------------------------
# LOAD MODEL
# ------------------------------------------------------
# Built once at import time; every call to inference() reuses this instance.
_MODEL_NAME = DEFAULT_MODEL
print("Default SpeciesNet model:", _MODEL_NAME)
print("Supported SpeciesNet models:", SUPPORTED_MODELS)
model = SpeciesNet(_MODEL_NAME)
# ------------------------------------------------------
# VALIDATION
# ------------------------------------------------------
def validate_predictions_structure(pred):
    """Check that one model prediction has the structure this app relies on.

    Downstream code unpacks ``detections[i]["bbox"]`` into four values and
    pairs ``classifications["scores"][0]`` with ``classifications["classes"][0]``,
    so both shapes are verified here up front.

    Args:
        pred: One entry of ``predictions_dict["predictions"]``.

    Returns:
        True if the prediction is well-formed.

    Raises:
        ValueError: describing the first structural problem found.
    """
    if not isinstance(pred, dict):
        raise ValueError("prediction must be a dict")
    for key in ("filepath", "detections", "classifications"):
        if key not in pred:
            raise ValueError(f"Missing key '{key}'")

    detections = pred["detections"]
    if not isinstance(detections, list):
        raise ValueError("detections must be list")
    for det in detections:
        bbox = det.get("bbox") if isinstance(det, dict) else None
        if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
            raise ValueError("each detection needs a 4-value bbox")

    cls = pred["classifications"]
    # `in` on a list or string matches elements/substrings, so insist on a dict.
    if not isinstance(cls, dict) or "classes" not in cls or "scores" not in cls:
        raise ValueError("classification format invalid")
    classes, scores = cls["classes"], cls["scores"]
    if (
        not isinstance(classes, (list, tuple))
        or not isinstance(scores, (list, tuple))
        or len(classes) != len(scores)
    ):
        raise ValueError("classification format invalid")
    return True
def validate_model_output(predictions_dict):
    """Validate the top-level model result and every prediction inside it.

    Args:
        predictions_dict: Object returned by ``model.predict``.

    Returns:
        True when everything is valid (matching validate_predictions_structure).

    Raises:
        ValueError: if ``predictions`` is missing or not a list, or any entry
            fails validate_predictions_structure.
    """
    if not isinstance(predictions_dict, dict) or "predictions" not in predictions_dict:
        raise ValueError("Missing predictions")
    predictions = predictions_dict["predictions"]
    # Iterating a dict/str here would yield keys/characters and produce
    # misleading per-item errors, so reject non-lists explicitly.
    if not isinstance(predictions, list):
        raise ValueError("predictions must be list")
    for pred in predictions:
        validate_predictions_structure(pred)
    return True
# ------------------------------------------------------
# SAVE YOLO TXT FORMAT
# ------------------------------------------------------
def save_yolo_annotations(image_path, predictions_dict, txt_path):
    """Write one YOLO-style line per detected box to ``txt_path``.

    Format (coordinates normalized to [0, 1]):
        class_name x_center y_center width height

    Unlike standard YOLO, the first field is the species name (last
    ``;``-separated part of the top classification) rather than an integer id.
    Whitespace inside that name (e.g. "white-tailed deer") is replaced with
    "_" so every line always has exactly five space-separated fields.

    ``bbox`` values are taken as already-normalized
    ``[x_min, y_min, width, height]``, so the image itself is not read.

    Args:
        image_path: Unused; kept so existing callers keep working.
        predictions_dict: Model output containing a ``"predictions"`` list.
        txt_path: Destination file; overwritten if it exists.
    """
    lines = []
    for pred in predictions_dict.get("predictions", []):
        classes = pred.get("classifications", {}).get("classes", [])
        if not classes:
            continue
        raw_name = classes[0].split(";")[-1]
        # Collapse inner whitespace so the label stays a single field.
        class_name = "_".join(raw_name.split()) or "unknown"
        for det in pred.get("detections", []):
            x, y, bw, bh = det["bbox"]
            x_center = x + bw / 2
            y_center = y + bh / 2
            lines.append(f"{class_name} {x_center:.6f} {y_center:.6f} {bw:.6f} {bh:.6f}")

    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
# ------------------------------------------------------
# UPLOAD TO BUCKET (SAFE)
# ------------------------------------------------------
def upload_to_bucket(image_path, txt_path, image_id):
    """Best-effort upload of an image and its label file to ``BUCKET_ID``.

    Remote paths are ``images/<image_id>.jpg`` and ``labels/<image_id>.txt``.
    Problems are printed and never raised, so a storage failure cannot break
    the user-facing request.

    Args:
        image_path: Local path of the image to upload.
        txt_path: Local path of the matching annotation file.
        image_id: Base name used for both remote paths.

    Returns:
        True if the upload call completed, False if it was skipped or failed.
    """
    # An empty secret ("") is as unusable as a missing one.
    if not HF_TOKEN:
        print("⚠ HF_TOKEN missing → skipping upload")
        return False
    try:
        batch_bucket_files(
            BUCKET_ID,
            add=[
                (image_path, f"images/{image_id}.jpg"),
                (txt_path, f"labels/{image_id}.txt"),
            ],
            token=HF_TOKEN,
        )
    except Exception as e:  # deliberate best effort: report, don't fail the request
        print("⚠ Upload failed:", str(e))
        return False
    print("✅ Uploaded to bucket")
    return True
# ------------------------------------------------------
# DRAW BOXES
# ------------------------------------------------------
def draw_predictions(image_path, predictions_dict):
    """Render detection boxes and the top species label onto the image.

    Every detection of a prediction is labelled with that prediction's top
    classification (last ``;``-separated part of ``classes[0]``) and, when
    present, its score. ``bbox`` values are scaled by the image width/height,
    i.e. treated as normalized ``[x_min, y_min, width, height]``.

    Args:
        image_path: Path of the image to annotate.
        predictions_dict: Model output containing a ``"predictions"`` list.

    Returns:
        The annotated image as an RGB numpy array (OpenCV loads BGR).

    Raises:
        ValueError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:  # cv2.imread signals failure by returning None, not raising
        raise ValueError(f"Could not read image: {image_path}")
    h, w = img.shape[:2]

    for pred in predictions_dict.get("predictions", []):
        cls = pred.get("classifications", {})
        classes = cls.get("classes", [])
        if not classes:
            continue
        scores = cls.get("scores", [])
        top_class = classes[0].split(";")[-1]
        # Same label for every box of this prediction; build it once.
        label = f"{top_class} {scores[0]:.2f}" if scores else top_class

        for det in pred.get("detections", []):
            x, y, bw, bh = det["bbox"]
            x1, y1 = int(x * w), int(y * h)
            x2, y2 = int((x + bw) * w), int((y + bh) * h)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # putText's origin is the text's bottom-left corner; for boxes near
            # the top edge, drawing above the box would leave the frame, so
            # draw the label just inside the box instead.
            text_y = y1 - 10 if y1 >= 25 else y1 + 20
            cv2.putText(img, label, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                        (255, 255, 255), 2)

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------------------------------------------
# INFERENCE FUNCTION
# ------------------------------------------------------
def inference(image):
    """Run the model on one uploaded image, persist the results, and visualize them.

    Writes ``<image_id>.jpg``, ``<image_id>.json`` and ``<image_id>.txt`` to the
    current working directory, then uploads the image and label file to the
    bucket (best effort).

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; ``None`` when
            the user submits without uploading anything.

    Returns:
        ``(annotated_rgb_image, predictions_json_string)`` for the two outputs.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
        ValueError: if the model output fails validation.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # A bare second-resolution timestamp collides when two requests land in the
    # same second, silently overwriting local files and bucket objects; the
    # random suffix keeps every run's paths unique.
    image_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    image_path = f"{image_id}.jpg"
    txt_path = f"{image_id}.txt"
    json_path = f"{image_id}.json"

    # JPEG cannot store alpha/palette modes; normalize defensively before saving.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(image_path)

    start = time.perf_counter()
    predictions_dict = model.predict(
        instances_dict={"instances": [{"filepath": image_path}]}
    )
    print(f"Inference time: {time.perf_counter() - start:.2f}s")

    validate_model_output(predictions_dict)

    # Serialize once: the same text is saved locally and returned to the UI.
    predictions_json = json.dumps(predictions_dict, indent=4)
    with open(json_path, "w", encoding="utf-8") as f:
        f.write(predictions_json)

    save_yolo_annotations(image_path, predictions_dict, txt_path)
    upload_to_bucket(image_path, txt_path, image_id)

    annotated = draw_predictions(image_path, predictions_dict)
    return annotated, predictions_json
# ------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Detection Output"),
        gr.JSON(label="Model Output"),
    ],
    title="Wildlife Detector + SpeciesNet",
    description="Upload wildlife image → detect + classify + save to Hugging Face bucket",
)

# Start the server only when this file is executed as a script, so importing
# the module (tests, reuse of the helpers) does not block inside launch().
# NOTE(review): assumes the Space runs this file as `python app.py` — confirm.
if __name__ == "__main__":
    iface.launch()