miojizzy commited on
Commit
87f517a
·
1 Parent(s): 5dcafd1

Update mhr/predict_tools.py

Browse files
Files changed (1) hide show
  1. mhr/predict_tools.py +28 -36
mhr/predict_tools.py CHANGED
@@ -44,14 +44,6 @@ class MHRVedioCuter:
44
  pos_w, pos_h, w, h = self.part_pos['skill']
45
  return [ img[pos_h+h*i:pos_h+h*i+h, pos_w:pos_w+w] for i in range(7) ]
46
 
47
- class MHRVedioSimpleCuter(MHRVedioCuter):
48
- def __init__(self):
49
- super(MHRVedioSimpleCuter, self).__init__()
50
-
51
- def iter(self, v):
52
- for img in v:
53
- yield self._cut_whole(img), self._cut_hole(img), self._cut_skill(img), label.format(idx/fps, idx)
54
-
55
 
56
 
57
  class MHRStoneRecognizeMgr:
@@ -70,40 +62,40 @@ class MHRStoneRecognizeMgr:
70
  self.skill_model = MyTrRecognizeNet(image_padding=2)
71
  self._vedio_cutter = vedio_cutter
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def recognize(self, vname, fname=None):
74
  def dump(rr):
75
  return '_'.join([ str(x) for x in rr[1] ]) + "|" + '|'.join([ x[0]+":"+x[1] for x in rr[2] ])
76
  results = []
77
  i=0
78
  for data in self._vedio_cutter.iter(vname):
79
- i+=1
80
- if i % 30 == 0:
81
- print(vname, i)
82
- #if len(results) == 1:
83
- # break
84
- #whole
85
- data_whole = tsfm_whole4cv(data[0])
86
- data_whole = data_whole.unsqueeze(0)
87
- if torch.cuda.is_available():
88
- data_whole = data_whole.cuda()
89
- ret = self.whole_model(data_whole)
90
- if ret[0][1] - ret[0][0] < 2:
91
- continue
92
- ret_whole = True
93
- #new hole
94
- data_hole = torch.cat([ tsfm_hole4cv(item).unsqueeze(0) for item in data[1] ], dim=0)
95
- if torch.cuda.is_available():
96
- data_hole = data_hole.cuda()
97
- output = self.hole_feat_model(data_hole)
98
- df = pd.DataFrame(output.tolist())
99
- ret_hole = list(self.hole_model.predict(df))
100
- #new skill
101
- data_skill = torch.cat([ tsfm_skill4cv(item).unsqueeze(0) for item in data[2] ], dim=0)
102
- ret = self.skill_model(data_skill)
103
- ret_skill = [ (x[0][0], x[1][0][-1]) for x in filter(lambda sk: sk[0][1] > 0.9 and sk[1][1] > 0.9, ret) ]
104
- # reuslt
105
- result = [data[3], ret_hole, ret_skill]
106
- if len(results) > 0 and dump(results[-1]) == dump(result):
107
  continue
108
  results.append(result)
109
  if fname:
 
44
  pos_w, pos_h, w, h = self.part_pos['skill']
45
  return [ img[pos_h+h*i:pos_h+h*i+h, pos_w:pos_w+w] for i in range(7) ]
46
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  class MHRStoneRecognizeMgr:
 
62
  self.skill_model = MyTrRecognizeNet(image_padding=2)
63
  self._vedio_cutter = vedio_cutter
64
 
65
+ def recognize_image(self, data):
66
+ data_whole = tsfm_whole4cv(data[0])
67
+ data_whole = data_whole.unsqueeze(0)
68
+ if torch.cuda.is_available():
69
+ data_whole = data_whole.cuda()
70
+ ret = self.whole_model(data_whole)
71
+ if ret[0][1] - ret[0][0] < 2:
72
+ return False, []
73
+ #new hole
74
+ data_hole = torch.cat([ tsfm_hole4cv(item).unsqueeze(0) for item in data[1] ], dim=0)
75
+ if torch.cuda.is_available():
76
+ data_hole = data_hole.cuda()
77
+ output = self.hole_feat_model(data_hole)
78
+ df = pd.DataFrame(output.tolist())
79
+ ret_hole = list(self.hole_model.predict(df))
80
+ #new skill
81
+ data_skill = torch.cat([ tsfm_skill4cv(item).unsqueeze(0) for item in data[2] ], dim=0)
82
+ ret = self.skill_model(data_skill)
83
+ ret_skill = [ (x[0][0], x[1][0][-1]) for x in filter(lambda sk: sk[0][1] > 0.9 and sk[1][1] > 0.9, ret) ]
84
+ # reuslt
85
+ result = [data[3], ret_hole, ret_skill]
86
+ return True, result
87
+
88
+ if len(results) > 0 and dump(results[-1]) == dump(result):
89
+ return
90
+
91
  def recognize(self, vname, fname=None):
92
  def dump(rr):
93
  return '_'.join([ str(x) for x in rr[1] ]) + "|" + '|'.join([ x[0]+":"+x[1] for x in rr[2] ])
94
  results = []
95
  i=0
96
  for data in self._vedio_cutter.iter(vname):
97
+ ok, result = self.recognize_image(data)
98
+ if not ok or (len(results) > 0 and dump(results[-1]) == dump(result)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  continue
100
  results.append(result)
101
  if fname: