Spaces:
Sleeping
Sleeping
Upload 9 files
Browse files- app.py +12 -1
- hole_model.pkl +3 -0
- mhr/common.py +28 -0
- mhr/config.py +63 -0
- mhr/custom_dateaset.py +26 -0
- mhr/custom_net.py +184 -0
- mhr/custom_transform.py +87 -0
- mhr/predict_tools.py +127 -0
- whole_model.pkl +3 -0
app.py
CHANGED
|
@@ -1,7 +1,18 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
def greet(v):
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
demo = gr.Interface(
|
| 7 |
fn=greet,
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
+
from mhr.predict_tools import MHRStoneRecognizeMgr, MHRVedioSimpleCuter

# Build the recognizer once at startup from the checkpoints bundled with the Space.
mgr = MHRStoneRecognizeMgr(
    whole_pkl="./whole_model.pkl",
    hole_pkl="./hole_model.pkl",
    vedio_cutter=MHRVedioSimpleCuter(),
)


def greet(v):
    """Gradio handler: run stone recognition on the uploaded input `v` and return the results."""
    print(v)
    result = mgr.recognize(v)
    print(result)
    return result
|
| 16 |
|
| 17 |
demo = gr.Interface(
|
| 18 |
fn=greet,
|
hole_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a01490c907afe2d131f31c577f6bee4736ada5ae0c0f7de50ab71da7be2e42df
|
| 3 |
+
size 98027
|
mhr/common.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np # linear algebra
|
| 2 |
+
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
# import
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torchvision
|
| 13 |
+
import torchvision.transforms as T
|
| 14 |
+
#from torch.utils.data import DataLoader, SubsetRandomSampler
|
| 15 |
+
#from torch.utils.tensorboard import SummaryWriter
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import numpy as np
|
| 18 |
+
import cv2 as cv
|
| 19 |
+
from PIL import Image,ImageDraw,ImageFont
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
|
| 22 |
+
from sklearn import datasets
|
| 23 |
+
from sklearn.cluster import KMeans
|
| 24 |
+
#from sklearn.externals import joblib
|
| 25 |
+
import pickle
|
| 26 |
+
# other
|
| 27 |
+
import tr
|
| 28 |
+
|
mhr/config.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mhr.custom_transform import *
from mhr.custom_net import *
from mhr.custom_dateaset import *

# video data root (only used by the commented-out dataset sanity checks below)
data_dir = "/app/data/"

# model checkpoint locations (server layout; app.py passes its own local paths instead)
whole_pkl_file = '/share_dir/mhr_data/whole_model.pkl'
hole_pkl_file = '/share_dir/mhr_data/hole_model.pkl'
speed_ratio = 5  # process every 5th frame of the video


# Whole-frame pipeline (PIL input): crop the central play area, downscale,
# then single-channel tensor for GaborFeatureNet.
tsfm_whole = T.Compose([
    T.CenterCrop([720,720]),
    T.Resize([360,360]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
# Same pipeline for OpenCV BGR ndarray input (converted to PIL first).
tsfm_whole4cv = T.Compose([
    CV2PIL(),
    T.CenterCrop([720,720]),
    T.Resize([360,360]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
#ds = MyDataset(set_name='whole', root=data_dir, transform=tsfm_whole, no_label=True)
#print("whole:", len(ds))


# Hole-crop pipeline (PIL input): binarize in OpenCV space, then 28x28
# single-channel tensor for GaborFeatureGen / KMeans.
tsfm_hole = T.Compose([
    PIL2CV(),
    CvtColor(cv.COLOR_BGR2GRAY),
    Threshold(125,255,cv.THRESH_BINARY),
    CV2PIL(),
    T.Resize([28,28]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
# Same pipeline for OpenCV BGR ndarray input (skips the PIL->CV step).
tsfm_hole4cv = T.Compose([
    CvtColor(cv.COLOR_BGR2GRAY),
    Threshold(125,255,cv.THRESH_BINARY),
    CV2PIL(),
    T.Resize([28,28]),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
#ds = MyDataset(set_name='hole', root=data_dir, transform=tsfm_hole, no_label=True)
#print("hole:", len(ds))


# Skill-crop pipeline: no resizing — MyTrRecognizeNet cuts fixed pixel regions.
tsfm_skill = T.Compose([
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
tsfm_skill4cv = T.Compose([
    CV2PIL(),
    T.ToTensor(),
    T.Grayscale(num_output_channels=1),
])
#ds = MyDataset(set_name='skill', root=data_dir, transform=tsfm_skill, no_label=True)
#print("skill:", len(ds))
|
| 63 |
+
|
mhr/custom_dateaset.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mhr.common import *
|
| 2 |
+
from mhr.custom_transform import *
|
| 3 |
+
|
| 4 |
+
# define dataset
|
| 5 |
+
# define dataset
class MyDataset(torch.utils.data.Dataset):
    """Image dataset driven by a CSV of (name, label) rows under root/label/.

    With no_label=True every row is usable and the file name doubles as the
    label; otherwise only rows whose 'label' column is non-NaN are served.

    Bug fix: the original indexed the NaN-filtered frame with `.loc[index]`,
    which uses the ORIGINAL index labels — so after filtering, positional
    indices could hit the wrong row or raise KeyError. The filtered frame is
    now cached once and accessed positionally with `.iloc`.
    """

    def __init__(self, set_name, root='.', transform=lambda x: x, no_label=False):
        super(MyDataset, self).__init__()
        self._root = root
        self._transform = transform
        self._df = pd.read_csv(self._root + '/label/' + set_name + '.csv')
        # Pre-filter labelled rows once (the original re-filtered on every access
        # and also contained a no-op `self._df = self._df`).
        self._labeled_df = self._df[~self._df['label'].isna()].reset_index(drop=True)
        self._no_label = no_label

    def __getitem__(self, index):
        """Return (transform(image), label) for the index-th usable row."""
        if self._no_label:
            # Row layout assumed to be (name, label); label column may be NaN here,
            # so the file name is used as the label placeholder.
            name, _ = self._df.iloc[index]
            label = name
        else:
            name, label = self._labeled_df.iloc[index]
        return self._transform(Image.open(self._root + '/image/' + name)), label

    def __len__(self):
        """Number of usable rows (all rows when no_label, labelled rows otherwise)."""
        if self._no_label:
            return len(self._df)
        return len(self._labeled_df)
|
mhr/custom_net.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mhr.common import *
|
| 2 |
+
from mhr.custom_transform import *
|
| 3 |
+
|
| 4 |
+
def getGaborFilters(ksize, n_output, sigma_ratio_func, theta_ratio_func, lamda_ratio_func, gamma=0.5, psi=0, show=False):
    """Build a bank of `n_output` Gabor kernels of size ksize x ksize.

    Each ratio function maps the filter index to a multiplier applied to a
    common base value of pi/2 for sigma (gaussian width), theta (orientation)
    and lamda (wavelength). Returns a list of float32 ndarrays.
    """
    base = np.pi / 2.0  # shared base for sigma, theta and lamda
    bank = []
    for i in range(n_output):
        #print(i, sigma_ratio_func(i), theta_ratio_func(i), lamda_ratio_func(i))
        kernel = cv.getGaborKernel(
            (ksize, ksize),
            base * sigma_ratio_func(i),
            base * theta_ratio_func(i),
            base * lamda_ratio_func(i),
            gamma,
            psi,
            ktype=cv.CV_32F,
        )
        bank.append(kernel)
    return bank
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TorchModelSaver:
    """Persist whole torch objects via torch.save / torch.load."""

    def __init__(self):
        pass

    def save(self, model, path):
        """Serialize `model` to `path`."""
        torch.save(model, path)

    def load(self, path):
        """Deserialize and return the object stored at `path`."""
        loaded = torch.load(path)
        return loaded
|
| 26 |
+
|
| 27 |
+
class SklearnModelSaver:
    """Persist sklearn models as pickle files.

    Bug fix: the original called `pickle.dumps(model, path)` — `dumps` takes a
    protocol number, not a path, so this raised TypeError — and
    `pickle.loads(path)`, which expects pickled bytes, not a file path. Both
    are replaced with the file-based `pickle.dump` / `pickle.load`.
    """

    def __init__(self):
        pass

    def save(self, model, path):
        """Write `model` to the file at `path`."""
        with open(path, 'wb') as f:
            pickle.dump(model, f)

    def load(self, path):
        """Read and return the model stored at `path`."""
        with open(path, 'rb') as f:
            return pickle.load(f)
|
| 34 |
+
|
| 35 |
+
# for whole
class GaborFeatureNet(nn.Module):
    """Whole-frame classifier: a fixed (frozen) Gabor filter bank as feature
    extractor followed by a trainable MLP head.

    Input is a single-channel image batch (N, 1, H, W); the hard-coded
    flattened size 12*34*34 matches the 360x360 inputs produced by the config
    transforms — assumes that input size; TODO confirm for others.
    """

    def __init__(self, num_classes, show_filters=False, show_images=False):
        super(GaborFeatureNet, self).__init__()
        # config of gabor filters: 3 scales x 4 orientations.
        # sigma/lamda step up every 4 filters; theta cycles 0..3 (x pi/2 base).
        ksize = 20
        n_output = 12
        sigma_func = lambda x: (x//4)/2+1
        theta_func = lambda x: (x%4)/2
        lamda_func = lambda x: x//4+1
        filters = getGaborFilters(ksize, n_output,
                                  sigma_func, theta_func, lamda_func
                                  )
        self.conv1 = torch.nn.Conv2d(1, n_output, (ksize,ksize),stride=1, bias=False)
        # Install the hand-built Gabor kernels as the conv weights (1 input channel).
        self.conv1.weight.data = torch.Tensor(filters).unsqueeze(1)
        self.pool1 = nn.Sigmoid()
        self.pool2 = nn.MaxPool2d(5)
        self.pool3 = nn.MaxPool2d(2)
        # Freeze everything created so far; only the classifier below trains.
        for p in self.parameters():
            p.requires_grad = False
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(12*34*34 , 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(1024, num_classes)
        )

    def forward(self, img):
        img = self.conv1(img)
        img = self.pool1(img)
        # Max-pool both polarities so strong responses of either sign survive.
        img = self.pool2(img) + self.pool2(-1*img)
        img = self.pool3(img)
        img = self.classifier(img)
        return img
|
| 73 |
+
|
| 74 |
+
class WholeModelMgr:
    """Bundle the whole-frame classifier with its torch-based persistence."""

    def __init__(self, num_classes):
        self.model = GaborFeatureNet(num_classes)
        self.saver = TorchModelSaver()

    def save(self, path):
        """Persist the current model to `path`."""
        self.saver.save(self.model, path)

    def load(self, path):
        """Replace the current model with the one stored at `path`."""
        self.model = self.saver.load(path)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# for hole
class GaborFeatureGen(nn.Module):
    """Feature generator for the hole crops: a Gabor bank (no classifier head)
    whose flattened responses feed the downstream KMeans model.

    `num_classes` is accepted for signature symmetry with GaborFeatureNet but
    is unused.
    """

    def __init__(self, num_classes, show_filters=False, show_images=False):
        super(GaborFeatureGen, self).__init__()
        # config of gabor filters — smaller sigma/lamda ratios than
        # GaborFeatureNet, presumably tuned for the 28x28 hole crops.
        ksize = 20
        n_output = 12
        sigma_func = lambda x: ((x//4)/2+2)/4
        theta_func = lambda x: (x%4)/2
        lamda_func = lambda x: (x//4+1)/2
        self.show_filters = show_filters
        self.show_images = show_images
        self.wins, self.dirs = self._get_wins_dirs(n_output, theta_func)
        filters = getGaborFilters(ksize, n_output,
                                  sigma_func, theta_func, lamda_func,
                                  psi=np.pi/2)
        self.conv1 = torch.nn.Conv2d(1, n_output, (ksize,ksize),stride=1,padding='same', bias=False)
        self.conv1.weight.data = torch.Tensor(filters).unsqueeze(1)
        self.pool1 = nn.Sigmoid()
        self.pool2 = nn.MaxPool2d(1)  # NOTE(review): unused in forward (see commented line there)
        self.pool3 = nn.MaxPool2d(2)
        if show_filters:
            self._show_img(self.wins, self.dirs, filters)

    def forward(self, img):
        img = self.conv1(img)
        img = self.pool1(img)
        #img = self.pool2(img) + self.pool2(-1*img)
        img = self.pool3(img)
        if self.show_images:
            self._show_img(self.wins, self.dirs, img[0])
        # Flatten to (N, features) for the KMeans consumer.
        return nn.Flatten()(img)

    def _show_img(self, wins,dirs,imgs):
        # Debug helper: tile the filters/responses in a wins x dirs grid.
        plt.figure(1)
        for i in range(len(imgs)):
            plt.subplot(wins, dirs, i+1)
            if type(imgs[i]) is np.ndarray:
                plt.imshow(imgs[i], cmap=plt.get_cmap('gray'))
            else:
                plt.imshow(T.functional.to_pil_image(imgs[i]), cmap=plt.get_cmap('gray'))
        plt.show()

    def _get_wins_dirs(self, n_output, theta_func):
        # Count distinct orientations, and filters-per-orientation, from theta_func.
        dirs = len(set([ theta_func(i) for i in range(n_output) ]))
        return n_output//dirs, dirs
|
| 130 |
+
|
| 131 |
+
class HoleModelMgr:
    """Bundle the hole feature generator with a KMeans clusterer and its
    pickle-based persistence (only the KMeans model is saved/loaded)."""

    def __init__(self, n_clusters):
        self.feat_model = GaborFeatureGen(0)
        self.model = KMeans(n_clusters)
        self.saver = SklearnModelSaver()

    def save(self, path):
        """Persist the KMeans model to `path`."""
        self.saver.save(self.model, path)

    def load(self, path):
        """Replace the KMeans model with the one stored at `path`."""
        self.model = self.saver.load(path)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# for skill
class MyTrRecognizeNet(torch.nn.Module):
    """OCR wrapper for the skill crops: cuts the name and level sub-regions
    out of each crop tensor and runs the `tr` text recognizer on both.

    Not a trainable net — forward returns (name_result, level_result) pairs
    from tr.recognize rather than tensors.
    """

    def __init__(self, image_padding):
        super(MyTrRecognizeNet, self).__init__()
        # image_padding enlarges each sub-region by `ip` pixels on every side.
        ip = image_padding if image_padding is not None else 0
        # Hard-coded pixel layout within one skill crop: name region...
        nm_p = [5, 22+16*0]
        nm_h = 18
        nm_w = 16*10
        # ...and level region.
        lv_p = [29, 193]
        lv_h = 18
        lv_w = 44
        #print([ x-ip for x in nm_p], nm_h+ip*2, nm_w+ip*2)
        #print([ x-ip for x in lv_p], lv_h+ip*2, lv_w+ip*2)
        self.tsfm_nm = T.Compose([
            TensorCut([ x-ip for x in nm_p], nm_h+ip*2, nm_w+ip*2),  # crop the skill-name area
        ])
        self.tsfm_lv = T.Compose([
            TensorCut([ x-ip for x in lv_p], lv_h+ip*2, lv_w+ip*2),  # crop the skill-level area
        ])

    def batch_forward(self, imgs):
        # One (name, level) pair per crop in the batch.
        return [ self.single_forward(img) for img in imgs ]

    def single_forward(self, img):
        # clone().detach() so TensorCut never aliases the caller's tensor.
        img_nm = self.tsfm_nm(img.clone().detach())
        img_lv = self.tsfm_lv(img.clone().detach())
        nm = tr.recognize(T.functional.to_pil_image(img_nm))
        lv = tr.recognize(T.functional.to_pil_image(img_lv))
        return nm,lv

    def forward(self, img):
        # 4-D input = batch of crops, 3-D = single crop.
        # NOTE(review): any other rank silently returns None.
        if len(img.shape) == 4:
            return self.batch_forward(img)
        elif len(img.shape) == 3:
            return self.single_forward(img)

    def forward_bak(self, img):
        # NOTE(review): superseded by forward/single_forward; kept as a backup
        # copy — candidate for deletion.
        if len(img.shape) == 4:
            img = img.squeeze(0)
        img_nm = self.tsfm_nm(img.clone().detach())
        img_lv = self.tsfm_lv(img.clone().detach())
        nm = tr.recognize(T.functional.to_pil_image(img_nm))
        lv = tr.recognize(T.functional.to_pil_image(img_lv))
        return nm,lv
|
mhr/custom_transform.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mhr.common import *
|
| 2 |
+
|
| 3 |
+
class PIL2CV(torch.nn.Module):
    """Transform: PIL RGB image -> OpenCV BGR ndarray."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        rgb_array = np.asarray(img)
        return cv.cvtColor(rgb_array, cv.COLOR_RGB2BGR)
|
| 8 |
+
|
| 9 |
+
class CV2PIL(torch.nn.Module):
    """Transform: OpenCV BGR ndarray -> PIL RGB image."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        return Image.fromarray(rgb)
|
| 14 |
+
|
| 15 |
+
class Tensor2CV(torch.nn.Module):
    """Transform: float image tensor (C,H,W or 1,C,H,W, values assumed in
    [0,1] — TODO confirm) -> OpenCV BGR uint8 ndarray."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        if len(img.shape) == 4:
            img = img.squeeze(0)  # drop a leading batch dimension of 1
        # NOTE(review): mul_/add_/clamp_ modify the input tensor in place —
        # callers must pass a tensor they don't reuse. 255*x + 0.5 rounds to
        # nearest before the uint8 cast.
        img = img.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).type(torch.uint8).numpy()
        img = cv.cvtColor(img, cv.COLOR_RGB2BGR)
        return img
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CvtColor(torch.nn.Module):
    """Transform: apply cv.cvtColor with a fixed conversion code."""

    def __init__(self, cvt):
        super().__init__()
        self._cvt = cvt  # an OpenCV COLOR_* conversion code

    def forward(self, img):
        converted = cv.cvtColor(img, self._cvt)
        return converted
|
| 32 |
+
|
| 33 |
+
class GaussianBlur(torch.nn.Module):
    """Transform: apply cv.GaussianBlur with a fixed kernel size and sigma."""

    def __init__(self, kernel, sigma):
        super().__init__()
        self._kernel = kernel  # (width, height) tuple of odd ints
        self._sigma = sigma

    def forward(self, img):
        blurred = cv.GaussianBlur(img, self._kernel, self._sigma)
        return blurred
|
| 40 |
+
|
| 41 |
+
class EqualizeHist(torch.nn.Module):
    """Transform: histogram-equalize a single-channel OpenCV image."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        return cv.equalizeHist(img)
|
| 46 |
+
|
| 47 |
+
class SobelBitwiseOrXY(torch.nn.Module):
    """Transform: bitwise-OR of the absolute horizontal and vertical Sobel
    gradients (a simple edge map)."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        grad_x = cv.convertScaleAbs(cv.Sobel(img, cv.CV_64F, 1, 0))
        grad_y = cv.convertScaleAbs(cv.Sobel(img, cv.CV_64F, 0, 1))
        return cv.bitwise_or(grad_x, grad_y)
|
| 56 |
+
|
| 57 |
+
class Threshold(torch.nn.Module):
    """Transform: apply cv.threshold with fixed parameters, returning only the
    thresholded image (the computed-threshold return value is dropped)."""

    def __init__(self, thresh, maxval, tt):
        super().__init__()
        self._thresh = thresh
        self._maxval = maxval
        self._tt = tt  # an OpenCV THRESH_* mode

    def forward(self, img):
        _, binarized = cv.threshold(img, self._thresh, self._maxval, self._tt)
        return binarized
|
| 66 |
+
|
| 67 |
+
class Cut(torch.nn.Module):
    """Transform: crop a 2-D image to rows [p0, p0+offsetx) and
    cols [p1, p1+offsety)."""

    def __init__(self, point, offsetx, offsety):
        super().__init__()
        self._p = point       # (row, col) top-left corner
        self._offsetx = offsetx  # number of rows kept
        self._offsety = offsety  # number of cols kept

    def forward(self, img):
        r0, c0 = self._p[0], self._p[1]
        return img[r0:r0 + self._offsetx, c0:c0 + self._offsety]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TensorCut(torch.nn.Module):
    """Transform: crop a (C,H,W) tensor to rows [p0, p0+offsetx) and
    cols [p1, p1+offsety), keeping all channels."""

    def __init__(self, point, offsetx, offsety):
        super().__init__()
        self._p = point       # (row, col) top-left corner
        self._offsetx = offsetx  # number of rows kept
        self._offsety = offsety  # number of cols kept

    def forward(self, img):
        r0, c0 = self._p[0], self._p[1]
        return img[:, r0:r0 + self._offsetx, c0:c0 + self._offsety]
|
mhr/predict_tools.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mhr.config import *
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class MHRVedioCuter:
    """Reads a video file and yields per-frame crops of the regions of interest
    (whole frame, 3 hole slots, 7 skill rows) plus a timestamp label.

    ("Vedio" spelling kept — the class name is part of the public API.)
    """

    def __init__(self, speed_ratio=1):
        ## config
        # Older rectangle-based layout kept for reference:
        #self.part_pos = {
        #    'pos':[(628,350),(993,565)],
        #    'page': [(781,575),(848,600)],
        #    'hole': [(1167,197),(1250,227)],
        #    'skill': [(1010,260),(1254,600)],
        #}
        # [x, y, width, height] of the FIRST hole slot / skill row; the other
        # slots/rows are laid out adjacently (see _cut_hole/_cut_skill).
        self.part_pos = {
            'hole': [1166,200,28,26],
            'skill': [1014,264,240,50],
        }
        self.speed_ratio = speed_ratio  # keep only every Nth frame

    def iter(self, name):
        """Generator over frames of video `name`: yields
        (whole_frame, hole_crops, skill_crops, time_label) tuples."""
        vc = cv.VideoCapture(name)
        fps = vc.get(cv.CAP_PROP_FPS)
        print("vedio:", vc.isOpened(), fps)
        # seconds.hundredths + frame index; the fixed "00:00:" prefix assumes
        # clips shorter than a minute — TODO confirm.
        label = "00:00:{:05.2f}({})"
        rval = True
        idx=0
        while rval:
            rval, img = vc.read()
            idx+=1
            if rval and idx%self.speed_ratio == 0:
                yield self._cut_whole(img), self._cut_hole(img), self._cut_skill(img), label.format(idx/fps, idx)

        vc.release()

    def _cut_whole(self, img):
        # The whole frame is used as-is.
        #pos_w, pos_h, w, h = self.part_pos['skill']
        return img

    def _cut_hole(self, img):
        # Three hole slots laid out horizontally, each w pixels wide.
        pos_w, pos_h, w, h = self.part_pos['hole']
        return [ img[pos_h:pos_h+h, pos_w+w*i:pos_w+w*i+w] for i in range(3) ]

    def _cut_skill(self, img):
        # Seven skill rows stacked vertically, each h pixels tall.
        pos_w, pos_h, w, h = self.part_pos['skill']
        return [ img[pos_h+h*i:pos_h+h*i+h, pos_w:pos_w+w] for i in range(7) ]
|
| 46 |
+
|
| 47 |
+
class MHRVedioSimpleCuter(MHRVedioCuter):
    """Cutter over an already-decoded iterable of frames (rather than a video
    file path), reusing the crop helpers of MHRVedioCuter."""

    def __init__(self):
        super(MHRVedioSimpleCuter, self).__init__()

    def iter(self, v):
        """Yield (whole, hole_crops, skill_crops, label) for each frame in `v`.

        Bug fix: the original body referenced `label`, `idx` and `fps`, which
        only exist inside MHRVedioCuter.iter, so the first yield raised
        NameError. The frame index is now tracked locally; no fps is known for
        a bare frame iterable, so the label carries just the index.
        """
        label = "frame({})"
        for idx, img in enumerate(v, start=1):
            yield self._cut_whole(img), self._cut_hole(img), self._cut_skill(img), label.format(idx)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class MHRStoneRecognizeMgr:
    """End-to-end recognizer: filters frames with the whole-frame classifier,
    clusters hole crops via Gabor features + KMeans, and OCRs skill crops."""

    def __init__(self, whole_pkl, hole_pkl, vedio_cutter):
        # NOTE(review): unused in the visible code — presumably remaps KMeans
        # cluster ids to hole labels; verify against callers.
        self.mapping_hole = [0,2,1,3,4]
        cp = torch.load(whole_pkl)  # checkpoint dict holding a 'model' state dict
        self.whole_model = GaborFeatureNet(num_classes=2)
        self.whole_model.load_state_dict(cp['model'])
        if torch.cuda.is_available():
            self.whole_model = self.whole_model.cuda()
        #self.whole_model = torch.load(whole_pkl)
        self.hole_feat_model = GaborFeatureGen(0)
        if torch.cuda.is_available():
            self.hole_feat_model = self.hole_feat_model.cuda()
        # NOTE(review): file handle is never closed (acceptable for one-shot init).
        self.hole_model = pickle.load(open(hole_pkl, 'rb'))
        self.skill_model = MyTrRecognizeNet(image_padding=2)
        self._vedio_cutter = vedio_cutter

    def recognize(self, vname, fname=None):
        """Run recognition over `vname` (whatever the configured cutter accepts).

        Returns a list of [label, hole_cluster_ids, skill_(name, level)_pairs]
        entries, with consecutive duplicates collapsed; also writes them as CSV
        lines to `fname` when given.
        """
        def dump(rr):
            # Canonical string of (holes, skills) used to detect consecutive duplicates.
            return '_'.join([ str(x) for x in rr[1] ]) + "|" + '|'.join([ x[0]+":"+x[1] for x in rr[2] ])
        results = []
        i=0
        for data in self._vedio_cutter.iter(vname):
            i+=1
            #if i < 200 and i > 250:
            #    continue
            #if len(results) == 1:
            #    break
            #whole
            data_whole = tsfm_whole4cv(data[0])
            data_whole = data_whole.unsqueeze(0)  # add a batch dimension
            if torch.cuda.is_available():
                data_whole = data_whole.cuda()
            ret = self.whole_model(data_whole)
            # Margin threshold on the two logits: skip frames that are not
            # confidently the stone screen.
            if ret[0][1] - ret[0][0] < 2:
                continue
            ret_whole = True
            #new hole
            data_hole = torch.cat([ tsfm_hole4cv(item).unsqueeze(0) for item in data[1] ], dim=0)
            if torch.cuda.is_available():
                data_hole = data_hole.cuda()
            output = self.hole_feat_model(data_hole)
            df = pd.DataFrame(output.tolist())
            ret_hole = list(self.hole_model.predict(df))
            #new skill
            # NOTE(review): unlike whole/hole, data_skill is never moved to
            # CUDA — confirm this is intentional.
            data_skill = torch.cat([ tsfm_skill4cv(item).unsqueeze(0) for item in data[2] ], dim=0)
            ret = self.skill_model(data_skill)
            # Keep OCR hits whose name and level confidences both exceed 0.9;
            # assumes tr.recognize returns [(text, confidence), ...] — TODO confirm.
            ret_skill = [ (x[0][0], x[1][0][-1]) for x in filter(lambda sk: sk[0][1] > 0.9 and sk[1][1] > 0.9, ret) ]
            # result
            result = [data[3], ret_hole, ret_skill]
            # Collapse frames identical to the previously accepted one.
            if len(results) > 0 and dump(results[-1]) == dump(result):
                continue
            results.append(result)
        if fname:
            with open(fname, 'w') as f:
                for result in results:
                    line = result[0]
                    line += ','
                    line += ','.join([ str(x) for x in result[1] ])
                    line += ','
                    line += ','.join([ x[0]+","+x[1] for x in result[2] ])
                    line += '\n'
                    f.write(line)
        return results
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# NOTE(review): module-level side effect — importing mhr.predict_tools builds a
# recognizer eagerly, loading both checkpoints from the /share_dir paths in
# mhr.config; the import fails wherever those files are absent (app.py already
# constructs its own MHRStoneRecognizeMgr with local paths). Consider guarding
# this with `if __name__ == "__main__":`.
mgr = MHRStoneRecognizeMgr(
    whole_pkl = whole_pkl_file,
    hole_pkl = hole_pkl_file,
    vedio_cutter = MHRVedioCuter(speed_ratio),
)
|
whole_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:267011c5b8bb860dd61a4b3e769582d5f69296f9432a51b32b93ea8b168ccc72
|
| 3 |
+
size 61054629
|