Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- app.py +19 -0
- counting_model.pt +3 -0
- fish_feeding.py +164 -0
- length_model.pt +3 -0
- requirements.txt +7 -0
app.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
from fish_feeding import FishFeeding
|
| 4 |
+
|
| 5 |
+
# Build the predictor once at import time; `fish_feeding` below uses this
# module-level instance, so the model weights are loaded a single time.
model = FishFeeding()
model.load_models()
|
| 7 |
+
|
| 8 |
+
def fish_feeding(images):
    """Gradio handler: turn the uploaded image(s) into a feeding recommendation.

    Args:
        images: Either a single image array (what a ``gr.Image(type='numpy')``
            input delivers), an iterable of image arrays, or ``None`` when
            nothing was uploaded.

    Returns:
        dict: ``{"total_feed": ..., "times": ...}`` as produced by
        ``model.final_fish_feed``.
    """
    if images is None:
        # Nothing uploaded: evaluate an empty batch instead of crashing on
        # iteration over ``None``.
        frames = []
    elif isinstance(images, np.ndarray) and images.ndim <= 3:
        # A lone (H, W) or (H, W, C) array is ONE image. Iterating over it
        # would walk its pixel rows and feed those rows to the model.
        frames = [images]
    else:
        frames = list(images)

    # Build a new list rather than writing back into the caller's object.
    frames = [np.asarray(frame, dtype=np.uint8) for frame in frames]

    total_feed, times = model.final_fish_feed(frames)
    return {"total_feed": total_feed, "times": times}
|
| 14 |
+
|
| 15 |
+
# UI wiring: an image upload goes in, the handler's dict is rendered as JSON.
inputs = gr.Image(label="Upload fish images", type="numpy")
outputs = gr.JSON(label="Fish Feeding Results")

app = gr.Interface(
    fn=fish_feeding,
    inputs=inputs,
    outputs=outputs,
    title="Fish Feeding Predictor",
)
app.launch()
|
counting_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9dc64af07f0ed80fe8f9c9b9d286353136a649accf794c924af6b8832ae07a7
|
| 3 |
+
size 6238297
|
fish_feeding.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from transformers import pipeline
|
| 5 |
+
from ultralytics import YOLO
|
| 6 |
+
|
| 7 |
+
class FishFeeding:
    """Estimate a feed ration from images of fish.

    Three models cooperate:

    * a YOLO keypoint model (``length_model.pt``) locating head/back/belly/tail,
    * a monocular depth-estimation pipeline (``vinvino02/glpn-nyu``),
    * a YOLO detection model (``counting_model.pt``) used to count fish.

    ``load_models`` must be called before any prediction method.
    """

    # (inclusive upper weight bound, feed per fish, number of feedings).
    # Brackets are checked in order; the first whose bound is >= the weight
    # wins.
    # NOTE(review): the units of weight/feed and the period "times" refers to
    # are not stated anywhere in this file - confirm with the feeding chart
    # these numbers came from.
    _FEED_TABLE = (
        (50, 3.3, 2),
        (100, 4.8, 2),
        (250, 5.8, 2),
        (500, 8.4, 2),
        (750, 9.4, 1),
        (1000, 10.5, 1),
        (1500, 11.0, 1),
    )
    # Used above the last bracket (and for negative / NaN weights).
    _FEED_FALLBACK = (12.0, 1)

    def __init__(self, focal_length: float = 27.4) -> None:
        """Store configuration; no model is loaded here.

        Args:
            focal_length: Divisor applied to the depth-scaled pixel distance
                in ``predict_fish_length``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.collected_lengths = []
        self.focal_length = focal_length
        self.final_weight = None
        self.length_model_name = "length_model.pt"
        self.depth_model_name = "vinvino02/glpn-nyu"
        self.counting_model_name = "counting_model.pt"

    def load_models(self) -> None:
        """Instantiate the keypoint, depth and detection models."""
        self.fish_keypoints_model = YOLO(self.length_model_name)
        self.depth_model = pipeline(
            task="depth-estimation",
            model=self.depth_model_name,
            device=self.device,
        )
        self.fish_detection_model = YOLO(self.counting_model_name)

    @staticmethod
    def _to_index(normalized: float, size: int) -> int:
        """Scale a normalized coordinate to a valid index in ``[0, size - 1]``."""
        # A coordinate of exactly 1.0 would otherwise land one past the end.
        return min(max(int(normalized * size), 0), size - 1)

    @staticmethod
    def _feed_schedule(average_weight: float):
        """Return ``(feed_per_fish, times)`` for the given average weight."""
        if not average_weight >= 0:
            # Negative or NaN weights match no bracket.
            return FishFeeding._FEED_FALLBACK
        for upper_bound, feed, times in FishFeeding._FEED_TABLE:
            if average_weight <= upper_bound:
                return feed, times
        return FishFeeding._FEED_FALLBACK

    def predict_fish_length(self, frame):
        """Estimate the head-to-tail length of the first fish found in ``frame``.

        Args:
            frame: Image array accepted by ``PIL.Image.fromarray`` and by the
                YOLO keypoint model.

        Returns:
            The depth-scaled head/tail distance divided by ``focal_length``.

        Raises:
            ValueError: If no fish, or fewer than four keypoints, are found.
        """
        image_obj = Image.fromarray(frame).resize((640, 640))  # Adjust size as per requirement
        depth = self.depth_model(image_obj)["predicted_depth"]
        if hasattr(depth, "detach"):
            # Tensor output: detach and move to host memory before NumPy
            # conversion.
            depth = depth.detach().cpu()
        depth = np.array(depth).squeeze()

        # Keypoints come from the keypoint model, not the counting model.
        results = self.fish_keypoints_model(frame)[0]
        if results.keypoints is None:
            raise ValueError("No fish detected in the image")
        normalized = results.keypoints.xyn
        if normalized is None or len(normalized) == 0:
            # No instance at all: report it the same way as "no keypoints"
            # so callers that skip on ValueError keep working.
            raise ValueError("No fish detected in the image")

        keypoints = normalized[0].detach().cpu().numpy()
        if len(keypoints) < 4:
            raise ValueError("No fish detected in the image")
        # Keypoint order is head, back, belly, tail; only head and tail are
        # needed for the length estimate.
        head, tail = keypoints[0], keypoints[3]

        # ndarray shape is (rows, cols) == (height, width).
        depth_h, depth_w = depth.shape[:2]

        head_x = self._to_index(head[0], depth_w)
        head_y = self._to_index(head[1], depth_h)
        tail_x = self._to_index(tail[0], depth_w)
        tail_y = self._to_index(tail[1], depth_h)

        head_depth = depth[head_y, head_x]
        tail_depth = depth[tail_y, tail_x]

        fish_length = (
            np.sqrt(
                (head_x * head_depth - tail_x * tail_depth) ** 2
                + (head_y * head_depth - tail_y * tail_depth) ** 2
            )
            / self.focal_length
        )
        return fish_length

    def get_average_weight(self):
        """Return the weight derived from the mean of ``collected_lengths``.

        Uses ``0.014 * mean_length ** 3.02``; returns ``0`` when no length
        has been collected.
        """
        if not self.collected_lengths:
            return 0
        length_average = sum(self.collected_lengths) / len(self.collected_lengths)
        final_weight = 0.014 * length_average ** 3.02
        return final_weight

    def fish_counting(self, images):
        """Return the largest number of detections found in any single image.

        Args:
            images: Iterable of images accepted by the YOLO detection model.

        Returns:
            int: Maximum per-image detection count, ``0`` for no images.
        """
        counting_output = 0
        for im0 in images:
            results = self.fish_detection_model(im0)
            # The model returns one result object per input image; the
            # fish count is the number of boxes inside it, not the length of
            # that outer list.
            detections = sum(
                len(result.boxes) for result in results if result.boxes is not None
            )
            counting_output = max(counting_output, detections)

        return counting_output

    def final_fish_feed(self, images: list):
        """Compute the total feed and number of feedings for a batch of images.

        Args:
            images: Images of the same tank/pond.

        Returns:
            tuple: ``(total_feed, times)`` where ``total_feed`` is the
            per-fish feed multiplied by the fish count.
        """
        # Materialize once: the batch is traversed twice (lengths, then count).
        images = list(images)

        lengths = []
        for image in images:
            try:
                lengths.append(self.predict_fish_length(image))
            except ValueError:
                # No usable fish in this image; it simply does not contribute.
                continue
        # Start from this batch only so earlier calls cannot skew the average.
        self.collected_lengths = lengths

        feed, times = self._feed_schedule(self.get_average_weight())

        fish_count = self.fish_counting(images)
        total_feed = feed * fish_count
        return total_feed, times
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# if __name__ == "__main__":
|
| 129 |
+
# to_collect = 6
|
| 130 |
+
# collected = []
|
| 131 |
+
# video_path = "object_counting.mp4"
|
| 132 |
+
# cap = cv2.VideoCapture(video_path)
|
| 133 |
+
|
| 134 |
+
# fish_feeding = FishFeeding()
|
| 135 |
+
# fish_feeding.load_models()
|
| 136 |
+
|
| 137 |
+
# d = {"images": []}
|
| 138 |
+
|
| 139 |
+
# while True:
|
| 140 |
+
# ret, frame = cap.read()
|
| 141 |
+
# if not ret:
|
| 142 |
+
# break
|
| 143 |
+
|
| 144 |
+
# if len(collected) == to_collect:
|
| 145 |
+
# total_feed, times = fish_feeding.final_fish_feed(collected)
|
| 146 |
+
# print(f"Total feed: {total_feed}, Feed times: {times}")
|
| 147 |
+
# collected = []
|
| 148 |
+
|
| 149 |
+
# break
|
| 150 |
+
|
| 151 |
+
# collected.append(frame)
|
| 152 |
+
# d["images"].append(frame.tolist())
|
| 153 |
+
|
| 154 |
+
# if cv2.waitKey(1) & 0xFF == ord("q"):
|
| 155 |
+
# break
|
| 156 |
+
|
| 157 |
+
# cap.release()
|
| 158 |
+
# cv2.destroyAllWindows()
|
| 159 |
+
|
| 160 |
+
# # save d to json file
|
| 161 |
+
# import json
|
| 162 |
+
# with open("data.json", "w") as f:
|
| 163 |
+
# json.dump(d, f)
|
| 164 |
+
|
length_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:208dbde300963cfe83f5f4a9fdfd7a91e640d0410d01ce4e70c96d440cbc03d1
|
| 3 |
+
size 6403287
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pillow
|
| 2 |
+
ultralytics
|
| 3 |
+
transformers
|
| 4 |
+
fastapi
|
| 5 |
+
dill==0.3.8
|
| 6 |
+
gradio
|
| 7 |
+
torch
|