Yapp99 commited on
Commit
9dc317c
·
1 Parent(s): f0221f7
Files changed (2) hide show
  1. S1_YoloTimber.py +177 -0
  2. app.py +149 -0
S1_YoloTimber.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import torch
3
+ import hubconf
4
+ import os
5
+ from torch import nn, Tensor
6
+ import torch.nn.functional as F
7
+ import cv2
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ model_name = "yolov6s"
12
+
13
+ from yolov6.models.yolo import Model as YoloModel
14
+ from yolov6.utils.config import Config
15
+ config = Config.fromfile(f"configs/base/{model_name}_base_finetune.py")
16
+ device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+ class YoloBackbone(YoloModel):
19
+ def __init__(self, config, num_classes, device):
20
+ super().__init__(config, num_classes=num_classes)
21
+
22
+ self.to(device)
23
+ self.train()
24
+
25
+ def forward(self, x:Tensor) -> Tensor:
26
+ # x = self.backbone.forward(x)
27
+ # x = self.neck.forward(x)
28
+ # x = self.detect.forward(x)
29
+ # _,_,_,x = x
30
+ return x
31
+
32
+ class Interpreter(nn.Module):
33
+ def __init__(self,
34
+ class_count:int,
35
+ sample_yolo_output,
36
+ device,
37
+ ):
38
+ super().__init__()
39
+
40
+ c = 32
41
+
42
+ self.train()
43
+ self._conv1 = nn.Conv2d(in_channels= 3, out_channels= 2*c, kernel_size=5, padding=2)
44
+ self._conv2 = nn.Conv2d(in_channels= 2*c, out_channels= 4*c, kernel_size=5, padding=2)
45
+ self._conv3 = nn.Conv2d(in_channels= 4*c, out_channels= 8*c, kernel_size=5, padding=2)
46
+ self._conv4 = nn.Conv2d(in_channels= 8*c, out_channels=16*c, kernel_size=3, padding=1)
47
+ self._conv5 = nn.Conv2d(in_channels=16*c, out_channels=32*c, kernel_size=3, padding=1)
48
+ self._conv6 = nn.Conv2d(in_channels=32*c, out_channels=64*c, kernel_size=3, padding=1)
49
+
50
+ self._linear_size = self.calc_linear(sample_yolo_output)
51
+ print(self._linear_size)
52
+
53
+ self._fc1 = nn.Linear(self._linear_size,512)
54
+ self._fc2 = nn.Linear(512, class_count)
55
+
56
+ self.to(device)
57
+ self.device = device
58
+ self.training = True
59
+ self.train()
60
+
61
+ def calc_linear(self, sample_yolo_output) -> int:
62
+ x = self.convs(sample_yolo_output.to('cpu'))
63
+ return x.shape[-1]
64
+
65
+ def convs(self, x:Tensor) -> Tensor:
66
+ x = F.max_pool2d(F.relu(self._conv1(x)), (2,2))
67
+ x = F.max_pool2d(F.relu(self._conv2(x)), (2,2))
68
+ x = F.max_pool2d(F.relu(self._conv3(x)), (2,2))
69
+ x = F.max_pool2d(F.relu(self._conv4(x)), (2,2))
70
+ x = F.max_pool2d(F.relu(self._conv5(x)), (2,2))
71
+ x = F.max_pool2d(F.relu(self._conv6(x)), (2,2))
72
+ x = torch.flatten(x,1)
73
+ return x
74
+
75
+ def fc(self, x:Tensor) -> Tensor:
76
+ x = F.relu(self._fc1(x))
77
+ # x = F.relu(self._fc2(x))
78
+ x = self._fc2(x)
79
+ return x
80
+
81
+ def forward(self, x:list[Tensor]) -> Tensor:
82
+ x = self.convs(x)
83
+ x = self.fc(x)
84
+ return x
85
+
86
+ import patchify
87
+ from torchvision import transforms
88
+
89
+ class YoloTimber(nn.Module):
90
+ def __init__(self,
91
+ image_size: tuple[int,int],
92
+ yolo_model: YoloBackbone,
93
+ interpreter: Interpreter,
94
+ ):
95
+ super().__init__()
96
+ self.device = interpreter.device
97
+ self.yolo_model = yolo_model
98
+ self.image_size = image_size
99
+ self.interpreter = interpreter
100
+
101
+ def predict(self, img_path:str) -> Tensor:
102
+ img = cv2.imread(img_path)
103
+ img = Image.fromarray(img)
104
+ img = transforms.ToTensor()(img)
105
+ img = torchvision.transforms.Resize(self.image_size)(img)
106
+ img = img[None]
107
+ img = img.to(self.device)
108
+
109
+ preds = self.forward(img)
110
+ _, preds = torch.max(preds,1)
111
+ return preds
112
+
113
+ def forward(self, x:Tensor) -> Tensor:
114
+ x = self.yolo_model(x)
115
+ x = self.interpreter(x)
116
+ return x
117
+
118
+ def predict_large_image(self,
119
+ img: np.ndarray,
120
+ patch_size:int = 816,
121
+ ) -> Tensor:
122
+
123
+ L = patch_size
124
+ patches = patchify.patchify(img,(L,L,3),L)
125
+ w,h,_ = patches.shape[:3]
126
+ patches = patches.reshape(w*h,*patches.shape[3:]).transpose((0,3,1,2))
127
+
128
+ patches = torch.from_numpy(patches)
129
+
130
+ patches = patches.float() / 255
131
+ patches = transforms.Resize(self.image_size)(patches)
132
+ patches = patches.to(self.device)
133
+
134
+ preds = self.forward(patches)
135
+ _, preds = torch.max(preds,1)
136
+ preds = torch.mode(preds, 0).values
137
+ return preds
138
+
139
+ class_count = 41
140
+
141
+ def build_backbone(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> YoloBackbone:
142
+ return YoloBackbone(
143
+ config = config,
144
+ num_classes=class_count,
145
+ device = device
146
+ )
147
+
148
+ def build_interpreter(img_size=(640,640),
149
+ yolo_model = None,
150
+ device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
151
+ ) -> Interpreter:
152
+ img_size = list(img_size)
153
+ if yolo_model == None:
154
+ yolo_model = build_backbone(device)
155
+
156
+ x = torch.randn([3]+img_size).view([-1,3]+img_size).to(device)
157
+ x = yolo_model(x)
158
+
159
+ return Interpreter(class_count=class_count, sample_yolo_output=x, device=device)
160
+
161
+ def build_model(img_size = (640,640),
162
+ device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
163
+ ) -> YoloTimber:
164
+ yolo_model=build_backbone(device)
165
+ return YoloTimber(yolo_model=yolo_model,
166
+ image_size=img_size,
167
+ interpreter=build_interpreter(img_size, yolo_model, device))
168
+
169
+ if __name__ == "__main__":
170
+ model = build_model(img_size=(320,320))
171
+ DATA_DIR = "data/image/test"
172
+ dir = os.listdir(DATA_DIR)[0]
173
+ img_name = os.listdir(f"{DATA_DIR}/{dir}")[0]
174
+ img_path = f"{DATA_DIR}/{dir}/{img_name}"
175
+
176
+ out = model.predict_large_image(img_path)
177
+ print(out)
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gdown
2
+ import os
3
+ import torch
4
+ from S1_YoloTimber import YoloTimber
5
+ import gradio as gr
6
+ import numpy as np
7
+ import cv2
8
+ import pandas as pd
9
+
10
+ MODEL_LINK = "https://drive.google.com/file/d/1XMdyxlKg7iliN6ekJVn9v4o6HJCJ-ASb/view?usp=drive_link"
11
+ MODEL_PATH = "model.pt"
12
+
13
+ if not os.path.exists(MODEL_PATH):
14
+ print("Downloading model . . . ")
15
+ gdown.download(MODEL_LINK,MODEL_PATH,fuzzy=True)
16
+
17
+ model:YoloTimber = torch.load(MODEL_PATH)
18
+ model.image_size = (320,320)
19
+
20
+ def listdir_full(path: str) -> list[str]:
21
+ return [f"{path}/{p}" for p in os.listdir(path)]
22
+
23
+ SAMPLE_DIR = "data/image/test_full"
24
+ labels = os.listdir(SAMPLE_DIR)
25
+
26
+ class History():
27
+ cols = ["Image", "Prediction"]
28
+
29
+ def __init__(self, img, name) -> None:
30
+ self.img = resize_image(img)
31
+ self.name = name
32
+
33
+ MAX_IMG_LEN = 160
34
+ def resize_image(img):
35
+ h, w, _ = img.shape
36
+
37
+ if w > h:
38
+ w1 = MAX_IMG_LEN
39
+ h1 = int(h/w * MAX_IMG_LEN)
40
+ else:
41
+ h1 = MAX_IMG_LEN
42
+ w1 = int(w/h * MAX_IMG_LEN)
43
+ return cv2.resize(img,(w1,h1))
44
+
45
+ PD_COLS=["image","predicted species"]
46
+ MAX_HISTORY = 10
47
+
48
+ def classify(image: np.array, history):
49
+ if history == None: history = []
50
+
51
+ with torch.no_grad():
52
+ pred = model.predict_large_image(cv2.cvtColor(image, cv2. COLOR_RGB2BGR)).item()
53
+ pred = labels[pred]
54
+
55
+ history += [(resize_image(image), pred)]
56
+ hist = history[-MAX_HISTORY:]
57
+
58
+ return pred, *toggle_history_components(hist), history
59
+
60
+ def toggle_history_components(history: list[History]):
61
+ n_hidden = MAX_HISTORY - len(history)
62
+ images, names = list(zip(*history))
63
+
64
+ components = [gr.Image(x, visible=True) for x in images]
65
+ components += [gr.Image(visible=False)] * n_hidden
66
+ components += [gr.Markdown(x, visible=True) for x in names]
67
+ components += [gr.Markdown(visible=False)] * n_hidden
68
+ return components
69
+
70
+ def classification_tab():
71
+ with gr.Row():
72
+ with gr.Column():
73
+ image = gr.Image()
74
+ with gr.Row():
75
+ submit = gr.Button("Submit", variant='primary')
76
+ clear = gr.ClearButton(image)
77
+ pred = gr.Textbox(label="Prediction")
78
+
79
+ return image, submit, clear, pred
80
+
81
+ MAX_SAMPLE_COUNT = max([len(os.listdir(x)) for x in listdir_full(SAMPLE_DIR)])
82
+
83
+ def sample_tab(image_input, tabs):
84
+
85
+ def choose_image(image):
86
+ return gr.Image(image), gr.Tabs(selected=0)
87
+
88
+ def refresh_samples(species):
89
+ images = listdir_full(f"{SAMPLE_DIR}/{species}")
90
+ n_hidden = MAX_SAMPLE_COUNT-len(images)
91
+
92
+ components = [gr.Image(i,visible=True) for i in images]
93
+ components += [gr.Image(visible=False)] * n_hidden
94
+ components += [gr.Button(visible=True) for _ in images]
95
+ components += [gr.Button(visible=False)] * n_hidden
96
+ return components
97
+
98
+ dropdown = gr.Dropdown(labels, label="Species", value="Select a Species")
99
+
100
+ images = []
101
+ buttons = []
102
+
103
+ def sample_panel():
104
+ with gr.Column():
105
+ image = gr.Image(visible=False ,interactive=False, min_width=1)
106
+ select = gr.Button("Submit", variant='primary', visible=False)
107
+
108
+ images.append(image)
109
+ buttons.append(select)
110
+ select.click(choose_image, image, [image_input, tabs])
111
+
112
+ with gr.Row(): [sample_panel() for _ in range(MAX_SAMPLE_COUNT)]
113
+
114
+ dropdown.change(refresh_samples, dropdown, images+buttons)
115
+ return
116
+
117
+ def history_tab():
118
+ history_imgs = []
119
+ history_names = []
120
+ with gr.Row():
121
+ gr.Markdown("# Image")
122
+ gr.Markdown("# Species")
123
+ gr.Markdown("")
124
+
125
+ with gr.Column():
126
+ for _ in range(MAX_HISTORY):
127
+ with gr.Row():
128
+ history_imgs.append(gr.Image(height=200,visible=False))
129
+ history_names.append(gr.Markdown("A",visible=False))
130
+ gr.Markdown("")
131
+
132
+ return history_imgs + history_names
133
+
134
+ with gr.Blocks() as demo:
135
+ history = gr.State([])
136
+ with gr.Tabs() as tabs:
137
+ with gr.Tab("Classification", id=0):
138
+ image, submit, clear, pred = classification_tab()
139
+
140
+ with gr.Tab("Samples", id=1):
141
+ sample_tab(image, tabs)
142
+
143
+ with gr.Tab("History", id=2):
144
+ table_contents = history_tab()
145
+
146
+ # history = gr.Gallery(interactive=False)
147
+ submit.click(classify,[image, history],[pred, *table_contents, history])
148
+
149
+ demo.launch()