miojizzy committed on
Commit
a22da3a
·
1 Parent(s): 9336d9f

Upload 9 files

Browse files
app.py CHANGED
@@ -1,7 +1,18 @@
1
  import gradio as gr
2
 
 
 
 
 
 
 
 
 
3
  def greet(v):
4
- return "Hello !"
 
 
 
5
 
6
  demo = gr.Interface(
7
  fn=greet,
 
1
  import gradio as gr
2
 
3
from mhr.predict_tools import MHRStoneRecognizeMgr, MHRVedioSimpleCuter

# Build the recognizer once at startup so every gradio request reuses the
# already-loaded models.
# NOTE(review): paths are relative to the working directory — confirm the
# pickle files are deployed next to app.py.
mgr = MHRStoneRecognizeMgr(
    whole_pkl = "./whole_model.pkl",
    hole_pkl = "./hole_model.pkl",
    vedio_cutter = MHRVedioSimpleCuter(),
)
11
def greet(v):
    """Gradio handler: run stone recognition on the uploaded video ``v``.

    Returns whatever ``mgr.recognize`` produces (a list of per-frame results).
    """
    print(v)  # log the incoming file for debugging
    result = mgr.recognize(v)
    print(result)
    return result
16
 
17
  demo = gr.Interface(
18
  fn=greet,
hole_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a01490c907afe2d131f31c577f6bee4736ada5ae0c0f7de50ab71da7be2e42df
3
+ size 98027
mhr/common.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np # linear algebra
2
+ import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
3
+
4
+ import os
5
+
6
+ import random
7
+
8
+ # import
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchvision
13
+ import torchvision.transforms as T
14
+ #from torch.utils.data import DataLoader, SubsetRandomSampler
15
+ #from torch.utils.tensorboard import SummaryWriter
16
+ import pandas as pd
17
+ import numpy as np
18
+ import cv2 as cv
19
+ from PIL import Image,ImageDraw,ImageFont
20
+ import matplotlib.pyplot as plt
21
+
22
+ from sklearn import datasets
23
+ from sklearn.cluster import KMeans
24
+ #from sklearn.externals import joblib
25
+ import pickle
26
+ # other
27
+ import tr
28
+
mhr/config.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from mhr.custom_transform import *
from mhr.custom_net import *
from mhr.custom_dateaset import *

# video data directory (note: "vedio" spelling is used throughout the package)
data_dir = "/app/data/"

# model checkpoint paths
whole_pkl_file = '/share_dir/mhr_data/whole_model.pkl'
hole_pkl_file = '/share_dir/mhr_data/hole_model.pkl'
speed_ratio = 5  # process only every 5th video frame


# 'whole' pipeline (PIL input): centre-crop 720x720, shrink to 360x360,
# single-channel tensor.
tsfm_whole = T.Compose([
    T.CenterCrop([720,720]),
    T.Resize([360,360]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
# Same pipeline but accepting an OpenCV BGR ndarray as input.
tsfm_whole4cv = T.Compose([
    CV2PIL(),
    T.CenterCrop([720,720]),
    T.Resize([360,360]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
#ds = MyDataset(set_name='whole', root=data_dir, transform=tsfm_whole, no_label=True)
#print("whole:", len(ds))


# 'hole' pipeline (PIL input): grayscale + binary threshold, then 28x28 tensor.
tsfm_hole = T.Compose([
    PIL2CV(),
    CvtColor(cv.COLOR_BGR2GRAY),
    Threshold(125,255,cv.THRESH_BINARY),
    CV2PIL(),
    T.Resize([28,28]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
# OpenCV-input variant of the hole pipeline.
tsfm_hole4cv = T.Compose([
    CvtColor(cv.COLOR_BGR2GRAY),
    Threshold(125,255,cv.THRESH_BINARY),
    CV2PIL(),
    T.Resize([28,28]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
#ds = MyDataset(set_name='hole', root=data_dir, transform=tsfm_hole, no_label=True)
#print("hole:", len(ds))


# 'skill' pipeline: tensor conversion only (region crops happen inside
# MyTrRecognizeNet).
tsfm_skill = T.Compose([
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
tsfm_skill4cv = T.Compose([
    CV2PIL(),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
#ds = MyDataset(set_name='skill', root=data_dir, transform=tsfm_skill, no_label=True)
#print("skill:", len(ds))
63
+
mhr/custom_dateaset.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mhr.common import *
2
+ from mhr.custom_transform import *
3
+
4
+ # define dataset
5
class MyDataset(torch.utils.data.Dataset):
    """Image dataset backed by a CSV label file at ``<root>/label/<set_name>.csv``.

    When ``no_label`` is True every row is used and the image file name stands
    in for its label; otherwise rows whose ``label`` is NaN are skipped.

    Fix: removed the redundant ``self._df = self._df`` self-assignment.
    """

    def __init__(self, set_name, root='.', transform=lambda x: x, no_label=False):
        super(MyDataset, self).__init__()
        self._root = root
        self._transform = transform
        # CSV is expected to hold exactly a name column and a label column
        # (the row is tuple-unpacked into two values in __getitem__).
        self._df = pd.read_csv(self._root + '/label/' + set_name + '.csv')
        self._no_label = no_label

    def __getitem__(self, index):
        if self._no_label:
            name, label = self._df.loc[index]
            label = name  # unlabeled mode: the file name doubles as the label
        else:
            # NOTE(review): .loc on the filtered frame keeps the original row
            # index, so `index` must be a surviving row label — confirm callers
            # sample only valid indices.
            name, label = self._df[~self._df['label'].isna()].loc[index]
        return self._transform(Image.open(self._root + '/image/' + name)), label

    def __len__(self):
        # Size depends on whether NaN-labeled rows count.
        if self._no_label:
            return len(self._df)
        else:
            return len(self._df[~self._df['label'].isna()])
mhr/custom_net.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mhr.common import *
2
+ from mhr.custom_transform import *
3
+
4
def getGaborFilters(ksize, n_output, sigma_ratio_func, theta_ratio_func, lamda_ratio_func, gamma=0.5, psi=0, show=False):
    """Build a bank of ``n_output`` Gabor kernels of size ``ksize`` x ``ksize``.

    The gaussian width (sigma), orientation (theta) and wavelength (lambda)
    all share the base value pi/2; each ``*_ratio_func(i)`` scales its base
    per filter index. ``show`` is accepted but unused.
    """
    base = np.pi / 2.0  # common base for sigma, theta and lambda
    return [
        cv.getGaborKernel(
            (ksize, ksize),
            base * sigma_ratio_func(i),
            base * theta_ratio_func(i),
            base * lamda_ratio_func(i),
            gamma,
            psi,
            ktype=cv.CV_32F,
        )
        for i in range(n_output)
    ]
17
+
18
+
19
class TorchModelSaver:
    """Persist/restore objects with torch's native serialization.

    Saves the whole object, not just a ``state_dict``.
    """

    def __init__(self):
        pass

    def save(self, model, path):
        torch.save(model, path)

    def load(self, path):
        return torch.load(path)
26
+
27
class SklearnModelSaver:
    """Persist/restore objects with pickle.

    Bug fix: the previous version called ``pickle.dumps(model, path)`` and
    ``pickle.loads(path)``, which treat ``path`` as a protocol number and as
    pickled bytes respectively — both raise at runtime. Use the file-based
    ``dump``/``load`` instead.
    """

    def __init__(self):
        pass

    def save(self, model, path):
        with open(path, 'wb') as f:
            pickle.dump(model, f)

    def load(self, path):
        # NOTE: unpickling untrusted files is unsafe — only load model files
        # written by this application.
        with open(path, 'rb') as f:
            return pickle.load(f)
34
+
35
# for whole
class GaborFeatureNet(nn.Module):
    """Classifier for the 'whole' stone image.

    A fixed (frozen) bank of Gabor filters extracts texture features; only the
    fully-connected classifier head — created after the freeze loop — remains
    trainable. ``show_filters``/``show_images`` are accepted but unused here.
    """

    def __init__(self, num_classes, show_filters=False, show_images=False):
        super(GaborFeatureNet, self).__init__()
        # config of gabor filters: 12 kernels = 3 scales x 4 orientations
        ksize = 20
        n_output = 12
        sigma_func = lambda x: (x//4)/2+1
        theta_func = lambda x: (x%4)/2
        lamda_func = lambda x: x//4+1
        filters = getGaborFilters(ksize, n_output,
                        sigma_func, theta_func, lamda_func
                        )
        self.conv1 = torch.nn.Conv2d(1, n_output, (ksize,ksize),stride=1, bias=False)
        # Load the precomputed Gabor kernels as fixed conv weights.
        self.conv1.weight.data = torch.Tensor(filters).unsqueeze(1)
        self.pool1 = nn.Sigmoid()  # NOTE(review): named "pool" but is an activation
        self.pool2 = nn.MaxPool2d(5)
        self.pool3 = nn.MaxPool2d(2)
        # Freeze the feature-extraction parameters (original comment here was
        # mojibake; intent per the code: fix the extractor, train the head).
        for p in self.parameters():
            p.requires_grad = False
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 12*34*34 assumes a 360x360 input (as produced by tsfm_whole):
            # conv(k=20) -> 341, maxpool(5) -> 68, maxpool(2) -> 34.
            nn.Linear(12*34*34 , 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(1024, num_classes)
        )

    def forward(self, img):
        img = self.conv1(img)
        img = self.pool1(img)
        # Keep both strong positive and strong negative filter responses.
        img = self.pool2(img) + self.pool2(-1*img)
        img = self.pool3(img)
        img = self.classifier(img)
        return img
73
+
74
class WholeModelMgr:
    """Owns the 'whole' classifier model and its torch-based (de)serialization."""

    def __init__(self, num_classes):
        self.saver = TorchModelSaver()
        self.model = GaborFeatureNet(num_classes)

    def save(self, path):
        self.saver.save(self.model, path)

    def load(self, path):
        # Replaces the current model object with the deserialized one.
        self.model = self.saver.load(path)
82
+
83
+
84
# for hole
class GaborFeatureGen(nn.Module):
    """Fixed Gabor-filter feature extractor for hole images.

    Produces a flattened response vector intended for a KMeans model.
    ``num_classes`` is accepted for signature symmetry with GaborFeatureNet
    but unused. Input is presumably the 28x28 single-channel output of
    tsfm_hole — TODO confirm.
    """

    def __init__(self, num_classes, show_filters=False, show_images=False):
        super(GaborFeatureGen, self).__init__()
        # config of gabor filters
        ksize = 20
        n_output = 12
        sigma_func = lambda x: ((x//4)/2+2)/4
        theta_func = lambda x: (x%4)/2
        lamda_func = lambda x: (x//4+1)/2
        self.show_filters = show_filters
        self.show_images = show_images
        # wins x dirs grid layout used only by the debug plots.
        self.wins, self.dirs = self._get_wins_dirs(n_output, theta_func)
        filters = getGaborFilters(ksize, n_output,
                        sigma_func, theta_func, lamda_func,
                        psi=np.pi/2)
        # padding='same' keeps the spatial size through the convolution.
        self.conv1 = torch.nn.Conv2d(1, n_output, (ksize,ksize),stride=1,padding='same', bias=False)
        self.conv1.weight.data = torch.Tensor(filters).unsqueeze(1)
        self.pool1 = nn.Sigmoid()  # NOTE(review): named "pool" but is an activation
        self.pool2 = nn.MaxPool2d(1)  # identity pool; unused in forward (see commented line)
        self.pool3 = nn.MaxPool2d(2)
        if show_filters:
            self._show_img(self.wins, self.dirs, filters)

    def forward(self, img):
        img = self.conv1(img)
        img = self.pool1(img)
        #img = self.pool2(img) + self.pool2(-1*img)
        img = self.pool3(img)
        if self.show_images:
            self._show_img(self.wins, self.dirs, img[0])
        # Flatten per-sample responses into one feature vector.
        return nn.Flatten()(img)

    def _show_img(self, wins,dirs,imgs):
        # Debug helper: plot each filter/response in a wins x dirs grid.
        plt.figure(1)
        for i in range(len(imgs)):
            plt.subplot(wins, dirs, i+1)
            if type(imgs[i]) is np.ndarray:
                plt.imshow(imgs[i], cmap=plt.get_cmap('gray'))
            else:
                plt.imshow(T.functional.to_pil_image(imgs[i]), cmap=plt.get_cmap('gray'))
        plt.show()

    def _get_wins_dirs(self, n_output, theta_func):
        # dirs = number of distinct orientations; wins = filters per orientation.
        dirs = len(set([ theta_func(i) for i in range(n_output) ]))
        return n_output//dirs, dirs
130
+
131
class HoleModelMgr:
    """Owns the hole pipeline: a fixed Gabor feature extractor plus a KMeans model."""

    def __init__(self, n_clusters):
        self.saver = SklearnModelSaver()
        self.feat_model = GaborFeatureGen(0)  # 0: num_classes is unused by the generator
        self.model = KMeans(n_clusters)

    def save(self, path):
        # Only the KMeans model is persisted; the Gabor filters are fixed.
        self.saver.save(self.model, path)

    def load(self, path):
        self.model = self.saver.load(path)
140
+
141
+
142
+
143
# for skill
class MyTrRecognizeNet(torch.nn.Module):
    """OCR wrapper for the skill panel: crops the name and level regions of a
    skill-row image and runs the ``tr`` text recognizer on each.

    Not a trainable network — it only reuses the nn.Module call protocol.
    """

    def __init__(self, image_padding):
        super(MyTrRecognizeNet, self).__init__()
        # Extra pixels kept around each crop; None means no padding.
        ip = image_padding if image_padding is not None else 0
        # Hard-coded crop geometry: [row, col] anchors plus height/width.
        # NOTE(review): presumably tuned for the 240x50 skill rows produced by
        # MHRVedioCuter._cut_skill — confirm.
        nm_p = [5, 22+16*0]
        nm_h = 18
        nm_w = 16*10
        lv_p = [29, 193]
        lv_h = 18
        lv_w = 44
        #print([ x-ip for x in nm_p], nm_h+ip*2, nm_w+ip*2)
        #print([ x-ip for x in lv_p], lv_h+ip*2, lv_w+ip*2)
        self.tsfm_nm = T.Compose([
            TensorCut([ x-ip for x in nm_p], nm_h+ip*2, nm_w+ip*2), # name region (original comment was mojibake)
        ])
        self.tsfm_lv = T.Compose([
            TensorCut([ x-ip for x in lv_p], lv_h+ip*2, lv_w+ip*2), # level region (original comment was mojibake)
        ])

    def batch_forward(self, imgs):
        # Recognize every image of a batch independently.
        return [ self.single_forward(img) for img in imgs ]
    def single_forward(self, img):
        # Clone before cropping so the caller's tensor is never sliced in place.
        img_nm = self.tsfm_nm(img.clone().detach())
        img_lv = self.tsfm_lv(img.clone().detach())
        nm = tr.recognize(T.functional.to_pil_image(img_nm))
        lv = tr.recognize(T.functional.to_pil_image(img_lv))
        return nm,lv
    def forward(self, img):
        # Dispatch on tensor rank: 4-D = batch (N,C,H,W), 3-D = single (C,H,W).
        # NOTE(review): other ranks fall through and return None.
        if len(img.shape) == 4:
            return self.batch_forward(img)
        elif len(img.shape) == 3:
            return self.single_forward(img)

    def forward_bak(self, img):
        # NOTE(review): superseded by forward/single_forward — dead code kept
        # for reference; consider deleting.
        if len(img.shape) == 4:
            img = img.squeeze(0)
        img_nm = self.tsfm_nm(img.clone().detach())
        img_lv = self.tsfm_lv(img.clone().detach())
        nm = tr.recognize(T.functional.to_pil_image(img_nm))
        lv = tr.recognize(T.functional.to_pil_image(img_lv))
        return nm,lv
mhr/custom_transform.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mhr.common import *
2
+
3
class PIL2CV(torch.nn.Module):
    """Transform: PIL RGB image -> OpenCV BGR ndarray."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        arr = np.asarray(img)
        return cv.cvtColor(arr, cv.COLOR_RGB2BGR)
8
+
9
class CV2PIL(torch.nn.Module):
    """Transform: OpenCV BGR ndarray -> PIL RGB image."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        return Image.fromarray(rgb)
14
+
15
class Tensor2CV(torch.nn.Module):
    """Transform: float tensor (C,H,W) or (1,C,H,W) in [0,1] -> OpenCV BGR uint8 image.

    NOTE(review): uses in-place ops (mul_/add_), so the input tensor is
    modified — confirm callers do not reuse it afterwards.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img):
        if len(img.shape) == 4:
            img = img.squeeze(0)  # drop the batch dimension
        # Scale to 0..255 with rounding (+0.5 then clamp), move channels last.
        img = img.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy()
        return cv.cvtColor(img, cv.COLOR_RGB2BGR)
24
+
25
+
26
class CvtColor(torch.nn.Module):
    """Transform: apply an OpenCV color-space conversion (a COLOR_* code)."""

    def __init__(self, cvt):
        super().__init__()
        self._code = cvt  # OpenCV conversion code, e.g. cv.COLOR_BGR2GRAY

    def forward(self, img):
        return cv.cvtColor(img, self._code)
32
+
33
class GaussianBlur(torch.nn.Module):
    """Transform: OpenCV Gaussian blur with a fixed kernel size and sigma."""

    def __init__(self, kernel, sigma):
        super().__init__()
        self._ksize = kernel  # (w, h) kernel-size tuple expected by OpenCV
        self._sigma = sigma

    def forward(self, img):
        return cv.GaussianBlur(img, self._ksize, self._sigma)
40
+
41
class EqualizeHist(torch.nn.Module):
    """Transform: OpenCV histogram equalization (single-channel 8-bit input)."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        return cv.equalizeHist(img)
46
+
47
class SobelBitwiseOrXY(torch.nn.Module):
    """Transform: edge map from the bitwise OR of |Sobel_x| and |Sobel_y|."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        grad_x = cv.convertScaleAbs(cv.Sobel(img, cv.CV_64F, 1, 0))
        grad_y = cv.convertScaleAbs(cv.Sobel(img, cv.CV_64F, 0, 1))
        return cv.bitwise_or(grad_x, grad_y)
56
+
57
class Threshold(torch.nn.Module):
    """Transform: OpenCV fixed-level threshold; returns only the thresholded image
    (the computed threshold value is discarded)."""

    def __init__(self, thresh, maxval, tt):
        super().__init__()
        self._thresh = thresh
        self._maxval = maxval
        self._type = tt  # a cv.THRESH_* mode

    def forward(self, img):
        _retval, binary = cv.threshold(img, self._thresh, self._maxval, self._type)
        return binary
66
+
67
class Cut(torch.nn.Module):
    """Transform: crop a 2-D array to a window anchored at ``point``.

    ``point`` is indexed as (row, col); ``offsetx``/``offsety`` extend the
    row/column axes respectively.
    """

    def __init__(self, point, offsetx, offsety):
        super().__init__()
        self._point = point
        self._rows = offsetx
        self._cols = offsety

    def forward(self, img):
        r, c = self._point[0], self._point[1]
        return img[r:r + self._rows, c:c + self._cols]
76
+
77
+
78
+
79
class TensorCut(torch.nn.Module):
    """Transform: crop a (C,H,W) tensor to a spatial window anchored at ``point``,
    keeping all channels."""

    def __init__(self, point, offsetx, offsety):
        super().__init__()
        self._point = point
        self._rows = offsetx
        self._cols = offsety

    def forward(self, img):
        r, c = self._point[0], self._point[1]
        return img[:, r:r + self._rows, c:c + self._cols]
mhr/predict_tools.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mhr.config import *
2
+
3
+
4
class MHRVedioCuter:
    """Cuts sampled video frames into the regions of interest.

    Each yielded item is ``(whole, holes, skills, label)`` where ``holes`` is a
    list of 3 crops, ``skills`` a list of 7 crops, and ``label`` a timestamp
    string built from the frame index and the video fps.
    """

    def __init__(self, speed_ratio=1):
        # Region layout: [left, top, width, height] in pixels.
        # NOTE(review): coordinates are hard-coded for one specific UI layout.
        self.part_pos = {
            'hole': [1166,200,28,26],
            'skill': [1014,264,240,50],
        }
        # Only every `speed_ratio`-th frame is processed.
        self.speed_ratio = speed_ratio

    def iter(self, name):
        """Generator over the sampled, region-cut frames of video file ``name``."""
        cap = cv.VideoCapture(name)
        fps = cap.get(cv.CAP_PROP_FPS)
        print("vedio:", cap.isOpened(), fps)
        fmt = "00:00:{:05.2f}({})"
        ok = True
        frame_no = 0
        while ok:
            ok, frame = cap.read()
            frame_no += 1
            if ok and frame_no % self.speed_ratio == 0:
                yield (self._cut_whole(frame),
                       self._cut_hole(frame),
                       self._cut_skill(frame),
                       fmt.format(frame_no / fps, frame_no))
        cap.release()

    def _cut_whole(self, img):
        # The full frame is passed through unchanged for the whole classifier.
        return img

    def _cut_hole(self, img):
        # Three equal-width hole slots laid out horizontally.
        x, y, w, h = self.part_pos['hole']
        return [img[y:y + h, x + w * i:x + w * i + w] for i in range(3)]

    def _cut_skill(self, img):
        # Seven skill rows stacked vertically.
        x, y, w, h = self.part_pos['skill']
        return [img[y + h * i:y + h * i + h, x:x + w] for i in range(7)]
46
+
47
class MHRVedioSimpleCuter(MHRVedioCuter):
    """Cutter for an already-decoded iterable of frames (e.g. a list of ndarrays)
    instead of a video file path.

    Bug fix: the previous ``iter`` referenced the undefined names ``label``,
    ``idx`` and ``fps`` and raised NameError on the first frame. Labels are now
    built from a 1-based frame counter; since no fps is available for a plain
    iterable, the timestamp field carries the frame index itself.
    """

    def __init__(self):
        super(MHRVedioSimpleCuter, self).__init__()

    def iter(self, v):
        label = "00:00:{:05.2f}({})"
        for idx, img in enumerate(v, start=1):
            yield (self._cut_whole(img),
                   self._cut_hole(img),
                   self._cut_skill(img),
                   label.format(float(idx), idx))
54
+
55
+
56
+
57
class MHRStoneRecognizeMgr:
    """End-to-end recognizer: walks video frames via a cutter and runs the
    whole/hole/skill models on each sampled frame."""

    def __init__(self, whole_pkl, hole_pkl, vedio_cutter):
        # whole_pkl: torch checkpoint whose 'model' entry is a GaborFeatureNet state_dict.
        # hole_pkl: pickled clustering model exposing .predict (KMeans per custom_net).
        # vedio_cutter: object with .iter(video) yielding (whole, holes, skills, label).
        self.mapping_hole = [0,2,1,3,4]  # NOTE(review): looks like a cluster-id remap but is unused here — confirm
        cp = torch.load(whole_pkl)
        self.whole_model = GaborFeatureNet(num_classes=2)
        self.whole_model.load_state_dict(cp['model'])
        if torch.cuda.is_available():
            self.whole_model = self.whole_model.cuda()
        #self.whole_model = torch.load(whole_pkl)
        self.hole_feat_model = GaborFeatureGen(0)
        if torch.cuda.is_available():
            self.hole_feat_model = self.hole_feat_model.cuda()
        # NOTE(review): pickle.load on an untrusted file is unsafe — only load
        # model files produced by this project. The file handle is never closed.
        self.hole_model = pickle.load(open(hole_pkl, 'rb'))
        self.skill_model = MyTrRecognizeNet(image_padding=2)
        self._vedio_cutter = vedio_cutter

    def recognize(self, vname, fname=None):
        """Recognize stones in video ``vname``; optionally dump results to ``fname``.

        Returns a list of ``[label, hole_ids, [(skill_name, skill_level), ...]]``.
        """
        def dump(rr):
            # Canonical string form of a result, used to collapse duplicates.
            return '_'.join([ str(x) for x in rr[1] ]) + "|" + '|'.join([ x[0]+":"+x[1] for x in rr[2] ])
        results = []
        i=0
        for data in self._vedio_cutter.iter(vname):
            i+=1
            #if i < 200 and i > 250:
            #    continue
            #if len(results) == 1:
            #    break
            #whole
            data_whole = tsfm_whole4cv(data[0])
            data_whole = data_whole.unsqueeze(0)  # add batch dimension
            if torch.cuda.is_available():
                data_whole = data_whole.cuda()
            ret = self.whole_model(data_whole)
            # Skip frames the whole-classifier is not confident about
            # (logit margin between class 1 and class 0 below 2).
            if ret[0][1] - ret[0][0] < 2:
                continue
            ret_whole = True  # NOTE(review): unused flag
            #new hole
            data_hole = torch.cat([ tsfm_hole4cv(item).unsqueeze(0) for item in data[1] ], dim=0)
            if torch.cuda.is_available():
                data_hole = data_hole.cuda()
            output = self.hole_feat_model(data_hole)
            df = pd.DataFrame(output.tolist())
            ret_hole = list(self.hole_model.predict(df))
            #new skill
            data_skill = torch.cat([ tsfm_skill4cv(item).unsqueeze(0) for item in data[2] ], dim=0)
            ret = self.skill_model(data_skill)
            # Keep only OCR hits where both name and level confidences exceed 0.9.
            ret_skill = [ (x[0][0], x[1][0][-1]) for x in filter(lambda sk: sk[0][1] > 0.9 and sk[1][1] > 0.9, ret) ]
            # result
            result = [data[3], ret_hole, ret_skill]
            # Collapse consecutive duplicate detections.
            if len(results) > 0 and dump(results[-1]) == dump(result):
                continue
            results.append(result)
        if fname:
            # CSV-ish dump: label, hole ids, then name,level pairs.
            with open(fname, 'w') as f:
                for result in results:
                    line = result[0]
                    line += ','
                    line += ','.join([ str(x) for x in result[1] ])
                    line += ','
                    line += ','.join([ x[0]+","+x[1] for x in result[2] ])
                    line += '\n'
                    f.write(line)
        return results
120
+
121
+
122
+
123
# Module-level default manager built from the paths in mhr.config.
# NOTE(review): this runs at import time and will fail if the pickle files at
# the /share_dir paths are absent (e.g. when app.py builds its own manager
# with local paths) — consider guarding with `if __name__ == "__main__":`
# or constructing lazily.
mgr = MHRStoneRecognizeMgr(
    whole_pkl = whole_pkl_file,
    hole_pkl = hole_pkl_file,
    vedio_cutter = MHRVedioCuter(speed_ratio),
)
whole_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:267011c5b8bb860dd61a4b3e769582d5f69296f9432a51b32b93ea8b168ccc72
3
+ size 61054629