achouffe commited on
Commit
eccd5c5
·
verified ·
1 Parent(s): e8578a5

feat: initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10.12
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio app to showcase the pyronear model for early forest fire detection.
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Tuple
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ from PIL import Image
11
+ from ultralytics import YOLO
12
+
13
+
14
+ def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
15
+ """
16
+ Turn a BGR numpy array into a RGB numpy array when the array `a` represents
17
+ an image.
18
+ """
19
+ return a[:, :, ::-1]
20
+
21
+
22
+ def prediction_to_str(yolo_prediction) -> str:
23
+ """
24
+ Turn the yolo_prediction into a human friendly string.
25
+ """
26
+ boxes = yolo_prediction.boxes
27
+ classes = boxes.cls.cpu().numpy().astype(np.int8)
28
+ n_bear = len([c for c in classes if c == 0])
29
+ n_soft_coral = len([c for c in classes if c == 1])
30
+
31
+ return f"""{len(boxes.conf)} bear detected! Trigger the bear repellent 🐻"""
32
+
33
+
34
+ def predict(model: YOLO, pil_image: Image.Image) -> Tuple[Image.Image, str]:
35
+ """
36
+ Main interface function that runs the model on the provided pil_image and
37
+ returns the exepected tuple to populate the gradio interface.
38
+
39
+ Args:
40
+ model (YOLO): Loaded ultralytics YOLO model.
41
+ pil_image (PIL): image to run inference on.
42
+
43
+ Returns:
44
+ pil_image_with_prediction (PIL): image with prediction from the model.
45
+ raw_prediction_str (str): string representing the raw prediction from the
46
+ model.
47
+ """
48
+ predictions = model(pil_image)
49
+ prediction = predictions[0]
50
+ pil_image_with_prediction = Image.fromarray(bgr_to_rgb(prediction.plot()))
51
+ raw_prediction_str = prediction_to_str(prediction)
52
+
53
+ return (pil_image_with_prediction, raw_prediction_str)
54
+
55
+
56
+ def examples(dir_examples: Path) -> list[Path]:
57
+ """
58
+ List the images from the dir_examples directory.
59
+
60
+ Returns:
61
+ filepaths (list[Path]): list of image filepaths.
62
+ """
63
+ return list(dir_examples.glob("*.jpg"))
64
+
65
+
66
+ def load_model(filepath_weights: Path) -> YOLO:
67
+ """
68
+ Load the YOLO model given the filepath_weights.
69
+ """
70
+ return YOLO(filepath_weights)
71
+
72
+
73
+ # Main Gradio interface
74
+
75
+ MODEL_FILEPATH_WEIGHTS = Path("data/model/weights/model.pt")
76
+ DIR_EXAMPLES = Path("data/images/")
77
+ DEFAULT_IMAGE_INDEX = 1
78
+
79
+ with gr.Blocks() as demo:
80
+ model = load_model(MODEL_FILEPATH_WEIGHTS)
81
+ image_filepaths = examples(dir_examples=DIR_EXAMPLES)
82
+ default_value_input = Image.open(image_filepaths[DEFAULT_IMAGE_INDEX])
83
+ input = gr.Image(
84
+ value=default_value_input,
85
+ type="pil",
86
+ label="input image",
87
+ sources=["upload", "clipboard"],
88
+ )
89
+ output_image = gr.Image(type="pil", label="model prediction")
90
+ output_raw = gr.Text(label="raw prediction")
91
+
92
+ fn = lambda pil_image: predict(model=model, pil_image=pil_image)
93
+ gr.Interface(
94
+ title="ML model for detecting bears from camera traps 🐻",
95
+ fn=fn,
96
+ inputs=input,
97
+ outputs=[output_image, output_raw],
98
+ examples=image_filepaths,
99
+ flagging_mode="never",
100
+ )
101
+
102
+ demo.launch()
data/images/03080008.jpg ADDED

Git LFS Details

  • SHA256: de084fa768ef0ad3b7c1359ce5770055532bfde0e38577a01905a9e7b48a7f5a
  • Pointer size: 131 Bytes
  • Size of remote file: 967 kB
data/images/03270018.jpg ADDED

Git LFS Details

  • SHA256: 621c2f6d725a1c69eb08c872c85cd302b7072dcca8795a3b0997634ed6b1c29e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.9 MB
data/images/04070249.jpg ADDED

Git LFS Details

  • SHA256: a3dec784b117fc3a533f7e9c0a37497a6cb19b370b2241ca944fac3bb0925d44
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
data/images/I__00010.jpg ADDED

Git LFS Details

  • SHA256: 6375538b9a4188e8f92fc0d34ca253c1252d92e670f910f7c259d2200d476d2d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.95 MB
data/images/image1.jpg ADDED

Git LFS Details

  • SHA256: 2ac1a08760aa48a8e3164d450835c89d6a549131b64f30013484e3e385b8890f
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
data/model/weights/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47f41ed8d0867288d00d1ae823cb05ea90f412e81a66dd39939620f211a3eecd
3
+ size 6346222
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ultralytics==8.3.*
2
+ gradio==5.4.*