# SpeciesNet / app.py
# Hugging Face Space file by codewithRiz; last commit "Update app.py" (f9b7aa9, verified), 5.86 kB.
from IPython.display import display, JSON
import matplotlib.pyplot as plt
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
import numpy as np
import time
import gradio as gr
import json
import cv2
import os
from huggingface_hub import batch_bucket_files
# ------------------------------------------------------
# HF TOKEN (IMPORTANT)
# ------------------------------------------------------
# Access token for the Hugging Face bucket; configure it as a Spaces secret.
# When it is absent (None), uploads are skipped.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Bucket that receives uploaded images and their label files.
BUCKET_ID = "codewithRiz/Buck_data_storage"

# ------------------------------------------------------
# LOAD MODEL
# ------------------------------------------------------
print(f"Default SpeciesNet model: {DEFAULT_MODEL}")
print(f"Supported SpeciesNet models: {SUPPORTED_MODELS}")
# Built once at import time; every request below reuses this instance.
model = SpeciesNet(DEFAULT_MODEL)
# ------------------------------------------------------
# VALIDATION
# ------------------------------------------------------
def validate_predictions_structure(pred):
    """Check that one prediction entry has the shape the rest of the app uses.

    Args:
        pred: One element of ``predictions_dict["predictions"]``.

    Returns:
        True when the entry is well-formed.

    Raises:
        ValueError: If ``pred`` is not a dict, a required key is missing,
            ``detections`` is not a list, or ``classifications`` is not a
            dict holding parallel ``classes`` / ``scores`` sequences.
    """
    if not isinstance(pred, dict):
        raise ValueError("prediction must be a dict")
    for key in ("filepath", "detections", "classifications"):
        if key not in pred:
            raise ValueError(f"Missing key '{key}'")
    if not isinstance(pred["detections"], list):
        raise ValueError("detections must be list")
    cls = pred["classifications"]
    # A bare `in` test would also accept e.g. a list ["classes", "scores"],
    # so require a real mapping first.
    if not isinstance(cls, dict) or "classes" not in cls or "scores" not in cls:
        raise ValueError("classification format invalid")
    classes, scores = cls["classes"], cls["scores"]
    # Consumers pair classes[i] with scores[i], so the two must line up.
    if (
        not isinstance(classes, (list, tuple))
        or not isinstance(scores, (list, tuple))
        or len(classes) != len(scores)
    ):
        raise ValueError("classification format invalid")
    return True
def validate_model_output(predictions_dict):
    """Validate the complete result returned by ``model.predict``.

    Args:
        predictions_dict: Mapping expected to contain a ``"predictions"``
            list whose entries are checked by ``validate_predictions_structure``.

    Returns:
        True when every entry is well-formed, matching the return
        convention of ``validate_predictions_structure``.

    Raises:
        ValueError: If the container is not a dict, lacks ``"predictions"``,
            ``"predictions"`` is not a list, or any entry is malformed.
    """
    if not isinstance(predictions_dict, dict) or "predictions" not in predictions_dict:
        raise ValueError("Missing predictions")
    predictions = predictions_dict["predictions"]
    # Iterating a string or dict here would validate characters/keys instead
    # of prediction entries.
    if not isinstance(predictions, (list, tuple)):
        raise ValueError("predictions must be list")
    for pred in predictions:
        validate_predictions_structure(pred)
    return True
# ------------------------------------------------------
# SAVE YOLO TXT FORMAT
# ------------------------------------------------------
def save_yolo_annotations(image_path, predictions_dict, txt_path):
    """Write detections to a whitespace-separated, YOLO-style label file.

    Format, one line per detection::

        class_name x_center y_center width height (normalized)

    Bounding boxes are treated as normalized ``[x_min, y_min, width, height]``
    and converted to a box centre. Every detection in a prediction is labelled
    with that prediction's top classification. Note that the first column is
    a class *name*, not an integer id, so a name-to-id map is needed before
    feeding the file to standard YOLO tooling.

    Args:
        image_path: Path of the source image. Kept for interface
            compatibility; it is not read because the boxes are already
            normalized and need no pixel dimensions.
        predictions_dict: Output of ``model.predict``.
        txt_path: Destination label file (overwritten).
    """
    lines = []
    for pred in predictions_dict.get("predictions", []):
        classes = pred.get("classifications", {}).get("classes", [])
        if not classes:
            continue
        # The last ';'-separated field is used as the display name. Collapse
        # internal whitespace so it stays a single token (e.g.
        # "white-tailed deer" -> "white-tailed_deer"), and fall back to
        # "unknown" when that field is empty so no line starts with a space.
        class_name = "_".join(classes[0].split(";")[-1].split()) or "unknown"
        for det in pred.get("detections", []):
            x, y, bw, bh = det["bbox"]
            x_center = x + bw / 2
            y_center = y + bh / 2
            lines.append(
                f"{class_name} {x_center:.6f} {y_center:.6f} {bw:.6f} {bh:.6f}"
            )
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
# ------------------------------------------------------
# UPLOAD TO BUCKET (SAFE)
# ------------------------------------------------------
def upload_to_bucket(image_path, txt_path, image_id):
    """Best-effort upload of an image and its label file to ``BUCKET_ID``.

    The image is stored as ``images/<image_id>.jpg`` and the labels as
    ``labels/<image_id>.txt``. If ``HF_TOKEN`` is None, nothing is uploaded.
    Any exception raised during the upload is printed, not propagated, so
    the caller is never interrupted by a storage failure.
    """
    if HF_TOKEN is not None:
        remote_files = [
            (image_path, f"images/{image_id}.jpg"),
            (txt_path, f"labels/{image_id}.txt"),
        ]
        try:
            batch_bucket_files(BUCKET_ID, add=remote_files, token=HF_TOKEN)
            print("✅ Uploaded to bucket")
        except Exception as err:
            print("⚠ Upload failed:", str(err))
    else:
        print("⚠ HF_TOKEN missing → skipping upload")
# ------------------------------------------------------
# DRAW BOXES
# ------------------------------------------------------
def draw_predictions(image_path, predictions_dict):
    """Draw each detection box with its prediction's top class and score.

    Args:
        image_path: Path of the image on disk.
        predictions_dict: Output of ``model.predict``; bounding boxes are
            treated as normalized ``[x_min, y_min, width, height]``.

    Returns:
        The annotated image converted from OpenCV's BGR order to RGB, ready
        for the Gradio image output.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {image_path}")
    h, w = img.shape[:2]

    for pred in predictions_dict.get("predictions", []):
        cls = pred.get("classifications", {})
        classes = cls.get("classes", [])
        if not classes:
            continue
        top_class = classes[0].split(";")[-1]
        scores = cls.get("scores", [])
        # Fall back to the bare class name when no score is present rather
        # than failing with IndexError.
        label = f"{top_class} {scores[0]:.2f}" if scores else top_class

        for det in pred.get("detections", []):
            x, y, bw, bh = det["bbox"]
            x1, y1 = int(x * w), int(y * h)
            x2, y2 = int((x + bw) * w), int((y + bh) * h)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Draw the label above the box; when the box hugs the top edge
            # that would be off-canvas, so draw it just inside the box.
            text_y = y1 - 10 if y1 >= 25 else y1 + 20
            cv2.putText(img, label, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                        (255, 255, 255), 2)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------------------------------------------
# INFERENCE FUNCTION
# ------------------------------------------------------
def inference(image):
    """Run SpeciesNet on an uploaded image, persist the results, and render them.

    Side effects, all keyed by a per-request ``image_id``:
    * Saves the image as ``<image_id>.jpg`` in the working directory.
    * Saves the raw predictions as ``<image_id>.json``.
    * Saves the YOLO-style labels as ``<image_id>.txt``.
    * Uploads the image and the labels to the Hugging Face bucket.

    Args:
        image: PIL image from the Gradio input component (None when the user
            submits without uploading).

    Returns:
        Tuple of (annotated RGB image, predictions serialized as JSON text).

    Raises:
        gr.Error: If no image was provided.
    """
    import uuid  # local import: only used to make per-request ids unique

    if image is None:
        raise gr.Error("Please upload an image.")

    # A second-resolution timestamp alone collides when two requests land in
    # the same second, so one request would overwrite the other's local files
    # and bucket objects. Append a random suffix to keep ids unique.
    image_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    image_path = f"{image_id}.jpg"
    txt_path = f"{image_id}.txt"

    # JPEG cannot store alpha or palette images (e.g. RGBA PNG uploads);
    # saving those modes as .jpg raises OSError.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(image_path)

    start = time.perf_counter()
    predictions_dict = model.predict(
        instances_dict={"instances": [{"filepath": image_path}]}
    )
    print(f"Inference time: {time.perf_counter() - start:.2f}s")

    # validate
    validate_model_output(predictions_dict)

    # save JSON
    with open(f"{image_id}.json", "w", encoding="utf-8") as f:
        json.dump(predictions_dict, f, indent=4)

    # save YOLO txt
    save_yolo_annotations(image_path, predictions_dict, txt_path)

    # upload to HF bucket (best-effort; failures are only printed)
    upload_to_bucket(image_path, txt_path, image_id)

    # visualize
    annotated = draw_predictions(image_path, predictions_dict)
    return annotated, json.dumps(predictions_dict, indent=4)
# ------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------
# Components are created in the same order as before: the input first,
# then the two outputs.
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Image(label="Detection Output"),
    gr.JSON(label="Model Output"),
]

# Single-function web UI that wraps `inference`.
iface = gr.Interface(
    fn=inference,
    inputs=_input_component,
    outputs=_output_components,
    title="Wildlife Detector + SpeciesNet",
    description="Upload wildlife image → detect + classify + save to Hugging Face bucket",
)

iface.launch()