Spaces:
Sleeping
Sleeping
Update mhr/predict_tools.py
Browse files- mhr/predict_tools.py +28 -36
mhr/predict_tools.py
CHANGED
|
@@ -44,14 +44,6 @@ class MHRVedioCuter:
|
|
| 44 |
pos_w, pos_h, w, h = self.part_pos['skill']
|
| 45 |
return [ img[pos_h+h*i:pos_h+h*i+h, pos_w:pos_w+w] for i in range(7) ]
|
| 46 |
|
| 47 |
-
class MHRVedioSimpleCuter(MHRVedioCuter):
|
| 48 |
-
def __init__(self):
|
| 49 |
-
super(MHRVedioSimpleCuter, self).__init__()
|
| 50 |
-
|
| 51 |
-
def iter(self, v):
|
| 52 |
-
for img in v:
|
| 53 |
-
yield self._cut_whole(img), self._cut_hole(img), self._cut_skill(img), label.format(idx/fps, idx)
|
| 54 |
-
|
| 55 |
|
| 56 |
|
| 57 |
class MHRStoneRecognizeMgr:
|
|
@@ -70,40 +62,40 @@ class MHRStoneRecognizeMgr:
|
|
| 70 |
self.skill_model = MyTrRecognizeNet(image_padding=2)
|
| 71 |
self._vedio_cutter = vedio_cutter
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
def recognize(self, vname, fname=None):
|
| 74 |
def dump(rr):
|
| 75 |
return '_'.join([ str(x) for x in rr[1] ]) + "|" + '|'.join([ x[0]+":"+x[1] for x in rr[2] ])
|
| 76 |
results = []
|
| 77 |
i=0
|
| 78 |
for data in self._vedio_cutter.iter(vname):
|
| 79 |
-
|
| 80 |
-
if
|
| 81 |
-
print(vname, i)
|
| 82 |
-
#if len(results) == 1:
|
| 83 |
-
# break
|
| 84 |
-
#whole
|
| 85 |
-
data_whole = tsfm_whole4cv(data[0])
|
| 86 |
-
data_whole = data_whole.unsqueeze(0)
|
| 87 |
-
if torch.cuda.is_available():
|
| 88 |
-
data_whole = data_whole.cuda()
|
| 89 |
-
ret = self.whole_model(data_whole)
|
| 90 |
-
if ret[0][1] - ret[0][0] < 2:
|
| 91 |
-
continue
|
| 92 |
-
ret_whole = True
|
| 93 |
-
#new hole
|
| 94 |
-
data_hole = torch.cat([ tsfm_hole4cv(item).unsqueeze(0) for item in data[1] ], dim=0)
|
| 95 |
-
if torch.cuda.is_available():
|
| 96 |
-
data_hole = data_hole.cuda()
|
| 97 |
-
output = self.hole_feat_model(data_hole)
|
| 98 |
-
df = pd.DataFrame(output.tolist())
|
| 99 |
-
ret_hole = list(self.hole_model.predict(df))
|
| 100 |
-
#new skill
|
| 101 |
-
data_skill = torch.cat([ tsfm_skill4cv(item).unsqueeze(0) for item in data[2] ], dim=0)
|
| 102 |
-
ret = self.skill_model(data_skill)
|
| 103 |
-
ret_skill = [ (x[0][0], x[1][0][-1]) for x in filter(lambda sk: sk[0][1] > 0.9 and sk[1][1] > 0.9, ret) ]
|
| 104 |
-
# reuslt
|
| 105 |
-
result = [data[3], ret_hole, ret_skill]
|
| 106 |
-
if len(results) > 0 and dump(results[-1]) == dump(result):
|
| 107 |
continue
|
| 108 |
results.append(result)
|
| 109 |
if fname:
|
|
|
|
| 44 |
pos_w, pos_h, w, h = self.part_pos['skill']
|
| 45 |
return [ img[pos_h+h*i:pos_h+h*i+h, pos_w:pos_w+w] for i in range(7) ]
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
class MHRStoneRecognizeMgr:
|
|
|
|
| 62 |
self.skill_model = MyTrRecognizeNet(image_padding=2)
|
| 63 |
self._vedio_cutter = vedio_cutter
|
| 64 |
|
| 65 |
+
def recognize_image(self, data):
|
| 66 |
+
data_whole = tsfm_whole4cv(data[0])
|
| 67 |
+
data_whole = data_whole.unsqueeze(0)
|
| 68 |
+
if torch.cuda.is_available():
|
| 69 |
+
data_whole = data_whole.cuda()
|
| 70 |
+
ret = self.whole_model(data_whole)
|
| 71 |
+
if ret[0][1] - ret[0][0] < 2:
|
| 72 |
+
return False, []
|
| 73 |
+
#new hole
|
| 74 |
+
data_hole = torch.cat([ tsfm_hole4cv(item).unsqueeze(0) for item in data[1] ], dim=0)
|
| 75 |
+
if torch.cuda.is_available():
|
| 76 |
+
data_hole = data_hole.cuda()
|
| 77 |
+
output = self.hole_feat_model(data_hole)
|
| 78 |
+
df = pd.DataFrame(output.tolist())
|
| 79 |
+
ret_hole = list(self.hole_model.predict(df))
|
| 80 |
+
#new skill
|
| 81 |
+
data_skill = torch.cat([ tsfm_skill4cv(item).unsqueeze(0) for item in data[2] ], dim=0)
|
| 82 |
+
ret = self.skill_model(data_skill)
|
| 83 |
+
ret_skill = [ (x[0][0], x[1][0][-1]) for x in filter(lambda sk: sk[0][1] > 0.9 and sk[1][1] > 0.9, ret) ]
|
| 84 |
+
# reuslt
|
| 85 |
+
result = [data[3], ret_hole, ret_skill]
|
| 86 |
+
return True, result
|
| 87 |
+
|
| 88 |
+
if len(results) > 0 and dump(results[-1]) == dump(result):
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
def recognize(self, vname, fname=None):
|
| 92 |
def dump(rr):
|
| 93 |
return '_'.join([ str(x) for x in rr[1] ]) + "|" + '|'.join([ x[0]+":"+x[1] for x in rr[2] ])
|
| 94 |
results = []
|
| 95 |
i=0
|
| 96 |
for data in self._vedio_cutter.iter(vname):
|
| 97 |
+
ok, result = self.recognize_image(data)
|
| 98 |
+
if not ok or (len(results) > 0 and dump(results[-1]) == dump(result)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
continue
|
| 100 |
results.append(result)
|
| 101 |
if fname:
|