File size: 5,858 Bytes
f9b7aa9
 
 
 
6ca6607
f9b7aa9
6ca6607
 
f9b7aa9
ef15c37
 
6ca6607
66b74f8
f9b7aa9
66b74f8
f9b7aa9
 
66b74f8
6ca6607
f9b7aa9
 
 
 
 
 
66b74f8
6ca6607
f9b7aa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ca6607
 
f9b7aa9
6ca6607
f9b7aa9
 
 
 
 
9942747
6ca6607
66b74f8
 
 
 
f9b7aa9
66b74f8
 
 
 
 
 
 
 
 
 
 
 
 
 
f9b7aa9
6ca6607
66b74f8
 
 
 
 
f9b7aa9
66b74f8
f9b7aa9
 
66b74f8
 
 
 
 
 
f9b7aa9
 
 
 
66b74f8
 
f9b7aa9
66b74f8
 
f9b7aa9
66b74f8
 
 
f9b7aa9
66b74f8
f9b7aa9
 
 
066e161
f9b7aa9
 
 
066e161
f9b7aa9
 
066e161
f9b7aa9
 
066e161
f9b7aa9
 
9942747
f9b7aa9
 
6ca6607
f9b7aa9
 
 
 
6ca6607
f9b7aa9
ef15c37
f9b7aa9
 
 
 
6ca6607
f9b7aa9
6ca6607
 
ef15c37
f9b7aa9
6ca6607
f9b7aa9
ef15c37
f9b7aa9
 
 
9942747
f9b7aa9
9942747
f9b7aa9
6ca6607
f9b7aa9
 
 
 
 
 
 
 
 
066e161
f9b7aa9
 
6ca6607
f9b7aa9
 
6ca6607
f9b7aa9
 
 
6ca6607
f9b7aa9
 
6ca6607
f9b7aa9
 
6ca6607
f9b7aa9
 
6ca6607
f9b7aa9
6ca6607
 
 
66b74f8
6ca6607
f9b7aa9
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from IPython.display import display, JSON
import matplotlib.pyplot as plt
from speciesnet import DEFAULT_MODEL, SUPPORTED_MODELS, SpeciesNet
import numpy as np
import time
import gradio as gr
import json
import cv2
import os

from huggingface_hub import batch_bucket_files

# ------------------------------------------------------
# HF TOKEN (IMPORTANT)
# ------------------------------------------------------
# Access token for the Hugging Face bucket, read from the environment
# (on Spaces: a repository secret). It is ``None`` when unset, in which case
# upload_to_bucket skips uploading.
HF_TOKEN = os.getenv("HF_TOKEN")

# Target bucket; uploads go under images/ and labels/ inside it.
BUCKET_ID = "codewithRiz/Buck_data_storage"

# ------------------------------------------------------
# LOAD MODEL
# ------------------------------------------------------
print(f"Default SpeciesNet model: {DEFAULT_MODEL}")
print(f"Supported SpeciesNet models: {SUPPORTED_MODELS}")

# Built once at import time and reused by every request.
model = SpeciesNet(DEFAULT_MODEL)

# ------------------------------------------------------
# VALIDATION
# ------------------------------------------------------
def validate_predictions_structure(pred):
    """Check that one SpeciesNet prediction record has the fields this app reads.

    Args:
        pred: A single entry from ``predictions_dict["predictions"]``.

    Returns:
        True when the record passes all checks.

    Raises:
        ValueError: If ``pred`` is not a dict, a required key is missing,
            ``detections`` is not a list, or ``classifications`` is not a
            dict containing both ``classes`` and ``scores``.
    """
    if not isinstance(pred, dict):
        raise ValueError("prediction must be dict")

    required_keys = ["filepath", "detections", "classifications"]
    for key in required_keys:
        if key not in pred:
            raise ValueError(f"Missing key '{key}'")

    if not isinstance(pred["detections"], list):
        raise ValueError("detections must be list")

    cls = pred["classifications"]
    # Without the dict check, a str value would pass the ``in`` tests as a
    # substring match and only fail later when ``.get`` is called on it.
    if not isinstance(cls, dict) or "classes" not in cls or "scores" not in cls:
        raise ValueError("classification format invalid")

    return True


def validate_model_output(predictions_dict):
    """Validate the top-level structure returned by ``SpeciesNet.predict``.

    Args:
        predictions_dict: The raw model output; expected to be a dict with a
            ``"predictions"`` list whose entries are each checked with
            ``validate_predictions_structure``.

    Raises:
        ValueError: If the output is not a dict, has no ``"predictions"``
            key, ``"predictions"`` is not a list, or any entry is malformed.
    """
    # Checking the type first makes a ``None``/non-dict result raise the same
    # ValueError as a missing key, instead of a TypeError from ``in``.
    if not isinstance(predictions_dict, dict) or "predictions" not in predictions_dict:
        raise ValueError("Missing predictions")

    predictions = predictions_dict["predictions"]
    if not isinstance(predictions, list):
        raise ValueError("predictions must be list")

    for pred in predictions:
        validate_predictions_structure(pred)

# ------------------------------------------------------
# SAVE YOLO TXT FORMAT
# ------------------------------------------------------
def save_yolo_annotations(image_path, predictions_dict, txt_path):
    """Write one YOLO-style line per detection to ``txt_path``.

    Each line has the form::

        class_name x_center y_center width height

    Coordinates come from the detection ``bbox``, read as
    ``[x, y, width, height]``. They are presumably already normalized to the
    0-1 range, as in SpeciesNet output, and are not rescaled here. Standard
    YOLO expects an integer class id in the first column. This file instead
    stores the last ``;``-separated component of the top class label.
    Whitespace inside that name is replaced by ``_`` so every line still
    splits into exactly five fields.

    Args:
        image_path: Path of the source image. Kept for interface
            compatibility; the image is not read because the coordinates
            are written as-is.
        predictions_dict: Model output holding a ``"predictions"`` list.
        txt_path: Destination file; overwritten.
    """
    lines = []

    for pred in predictions_dict.get("predictions", []):
        classes = pred.get("classifications", {}).get("classes", [])
        if not classes:
            continue

        # e.g. "...;white-tailed deer" -> "white-tailed_deer": a raw space
        # would split the name across columns and corrupt the line. An empty
        # trailing component falls back to "unknown".
        class_name = "_".join(classes[0].split(";")[-1].split()) or "unknown"

        for det in pred.get("detections", []):
            x, y, bw, bh = det["bbox"]

            x_center = x + bw / 2
            y_center = y + bh / 2

            lines.append(
                f"{class_name} {x_center:.6f} {y_center:.6f} {bw:.6f} {bh:.6f}"
            )

    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


# ------------------------------------------------------
# UPLOAD TO BUCKET (SAFE)
# ------------------------------------------------------
def upload_to_bucket(image_path, txt_path, image_id):
    """Best-effort upload of an image and its label file to the HF bucket.

    The files are sent to ``images/<image_id>.jpg`` and
    ``labels/<image_id>.txt`` in ``BUCKET_ID``. Nothing is uploaded when
    ``HF_TOKEN`` is unset. Any exception from the upload is printed rather
    than propagated, so a storage problem never breaks inference.

    Args:
        image_path: Local path of the image to upload.
        txt_path: Local path of the label file to upload.
        image_id: Base name used for both remote paths.
    """
    if HF_TOKEN is not None:
        files_to_add = [
            (image_path, f"images/{image_id}.jpg"),
            (txt_path, f"labels/{image_id}.txt"),
        ]
        try:
            batch_bucket_files(BUCKET_ID, add=files_to_add, token=HF_TOKEN)
            print("✅ Uploaded to bucket")
        except Exception as err:
            # Deliberately broad: uploading is optional, so report and move on.
            print("⚠ Upload failed:", str(err))
    else:
        print("⚠ HF_TOKEN missing → skipping upload")


# ------------------------------------------------------
# DRAW BOXES
# ------------------------------------------------------
def draw_predictions(image_path, predictions_dict):
    """Draw detection boxes labeled with the top class and score onto an image.

    Args:
        image_path: Path of the image on disk, read with ``cv2.imread``.
        predictions_dict: Model output holding a ``"predictions"`` list.
            Each detection ``bbox`` is treated as ``[x, y, width, height]``
            in fractions of the image size; it is scaled by the pixel
            dimensions below.

    Returns:
        The annotated image as an RGB array. It is converted from OpenCV's
        BGR channel order before being returned.

    Raises:
        ValueError: If the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {image_path}")
    h, w = img.shape[:2]

    for pred in predictions_dict.get("predictions", []):
        detections = pred.get("detections", [])
        cls = pred.get("classifications", {})

        classes = cls.get("classes", [])
        scores = cls.get("scores", [])

        if not classes:
            continue

        top_class = classes[0].split(";")[-1]
        # Tolerate a classes list that arrives without matching scores.
        label = f"{top_class} {scores[0]:.2f}" if scores else top_class

        for det in detections:
            x, y, bw, bh = det["bbox"]

            x1 = int(x * w)
            y1 = int(y * h)
            x2 = int((x + bw) * w)
            y2 = int((y + bh) * h)

            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # putText's origin is the text's bottom-left corner. Draw above
            # the box when there is room; otherwise draw just inside it, so
            # the label is not pushed off the top edge of the image.
            text_y = y1 - 10 if y1 - 10 > 15 else y1 + 20
            cv2.putText(img, label, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                        (255, 255, 255), 2)

    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# ------------------------------------------------------
# INFERENCE FUNCTION
# ------------------------------------------------------
def inference(image):
    """Run SpeciesNet on one uploaded image, persist the results, and visualize them.

    The function performs these steps in order:

    1. Save the image to ``<id>.jpg`` in the working directory.
    2. Run the model and validate its output.
    3. Write ``<id>.json`` and the ``<id>.txt`` label file.
    4. Upload the image and label file to the bucket.
    5. Draw the detection boxes.

    Args:
        image: The uploaded ``PIL.Image`` from the Gradio input, or ``None``
            when the user submits without an image.

    Returns:
        A tuple ``(annotated_rgb_array, predictions_json_string)``.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If the model output fails validation.
    """
    import uuid

    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image.")

    # JPEG cannot store an alpha channel or a palette, so PIL refuses to
    # save e.g. RGBA images as .jpg.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # A seconds-only timestamp collides when two requests arrive in the same
    # second. That would silently overwrite each other's local files and
    # bucket objects. The random suffix keeps IDs unique while the prefix
    # keeps them time-sortable.
    image_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    image_path = f"{image_id}.jpg"
    txt_path = f"{image_id}.txt"

    image.save(image_path)

    # perf_counter is monotonic, so it is the right clock for durations.
    start = time.perf_counter()
    predictions_dict = model.predict(
        instances_dict={"instances": [{"filepath": image_path}]}
    )
    print(f"Inference time: {time.perf_counter() - start:.2f}s")

    validate_model_output(predictions_dict)

    # Local copy of the raw model output.
    with open(f"{image_id}.json", "w", encoding="utf-8") as f:
        json.dump(predictions_dict, f, indent=4)

    save_yolo_annotations(image_path, predictions_dict, txt_path)
    upload_to_bucket(image_path, txt_path, image_id)

    annotated = draw_predictions(image_path, predictions_dict)

    return annotated, json.dumps(predictions_dict, indent=4)


# ------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------
# One PIL image in; the annotated image and the raw model output out.
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Detection Output"),
        gr.JSON(label="Model Output"),
    ],
    title="Wildlife Detector + SpeciesNet",
    description=(
        "Upload wildlife image → detect + classify + save to Hugging Face bucket"
    ),
)

# Starts the web server when the module is executed.
iface.launch()