Spaces:
Runtime error
Runtime error
lyangas committed on
Commit ·
488bb56
1
Parent(s): 1efad19
add method predict_code for prediction code by group
Browse files
app.py
CHANGED
|
@@ -19,7 +19,7 @@ try:
|
|
| 19 |
except Exception as e:
|
| 20 |
print(f"ERROR: loading embedder failed with: {str(e)}")
|
| 21 |
|
| 22 |
-
|
| 23 |
classifiers_codes = {}
|
| 24 |
try:
|
| 25 |
for clf_name in os.listdir('classifiers/codes'):
|
|
@@ -28,10 +28,11 @@ try:
|
|
| 28 |
with open('classifiers/codes/'+clf_name, 'rb') as f:
|
| 29 |
model = pickle.load(f)
|
| 30 |
classifiers_codes[clf_name.split('.')[0]] = model
|
| 31 |
-
print(f'INFO: classifier {clf_name} loaded')
|
| 32 |
except Exception as e:
|
| 33 |
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
| 34 |
|
|
|
|
| 35 |
classifiers_groups = {}
|
| 36 |
try:
|
| 37 |
for clf_name in os.listdir('classifiers/groups'):
|
|
@@ -40,7 +41,21 @@ try:
|
|
| 40 |
with open('classifiers/groups/'+clf_name, 'rb') as f:
|
| 41 |
model = pickle.load(f)
|
| 42 |
classifiers_groups[clf_name.split('.')[0]] = model
|
| 43 |
-
print(f'INFO: classifier {clf_name} loaded')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
except Exception as e:
|
| 45 |
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
| 46 |
|
|
@@ -68,6 +83,17 @@ def classify_group(text, top_n):
|
|
| 68 |
preds[clf_name] = clf_preds
|
| 69 |
return preds
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def get_top_result(preds):
|
| 72 |
total_scores = {}
|
| 73 |
for clf_name, scores in preds.items():
|
|
@@ -97,7 +123,7 @@ def test():
|
|
| 97 |
return {'response': data}
|
| 98 |
|
| 99 |
@app.route("/predict", methods=['POST'])
|
| 100 |
-
def
|
| 101 |
data = request.json
|
| 102 |
base64_bytes = str(data['textB64']).encode("ascii")
|
| 103 |
sample_string_bytes = base64.b64decode(base64_bytes)
|
|
@@ -121,5 +147,28 @@ def read_root():
|
|
| 121 |
}
|
| 122 |
return result
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
if __name__ == "__main__":
|
| 125 |
app.run(host='0.0.0.0', port=7860)
|
|
|
|
| 19 |
except Exception as e:
|
| 20 |
print(f"ERROR: loading embedder failed with: {str(e)}")
|
| 21 |
|
| 22 |
+
print('Loading classifiers of codes')
|
| 23 |
classifiers_codes = {}
|
| 24 |
try:
|
| 25 |
for clf_name in os.listdir('classifiers/codes'):
|
|
|
|
| 28 |
with open('classifiers/codes/'+clf_name, 'rb') as f:
|
| 29 |
model = pickle.load(f)
|
| 30 |
classifiers_codes[clf_name.split('.')[0]] = model
|
| 31 |
+
print(f'INFO: codes classifier {clf_name} loaded')
|
| 32 |
except Exception as e:
|
| 33 |
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
| 34 |
|
| 35 |
+
print('Loading classifiers of groups')
|
| 36 |
classifiers_groups = {}
|
| 37 |
try:
|
| 38 |
for clf_name in os.listdir('classifiers/groups'):
|
|
|
|
| 41 |
with open('classifiers/groups/'+clf_name, 'rb') as f:
|
| 42 |
model = pickle.load(f)
|
| 43 |
classifiers_groups[clf_name.split('.')[0]] = model
|
| 44 |
+
print(f'INFO: groups classifier {clf_name} loaded')
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
| 47 |
+
|
| 48 |
+
# Load the per-group code classifiers: every non-hidden file in
# classifiers/codes_in_groups is unpickled and registered in groups_models
# under its file name with the '_code_clf.pkl' part removed.
print('Loading classifiers in groups')
groups_models = {}
try:
    for clf_name in os.listdir('classifiers/codes_in_groups'):
        # Entries whose name starts with a dot are skipped.
        if clf_name[0] != '.':
            with open(f'classifiers/codes_in_groups/{clf_name}', 'rb') as f:
                # NOTE(review): pickle runs arbitrary code on load; this
                # presumes the directory only holds trusted files.
                model = pickle.load(f)
            group_name = clf_name.replace('_code_clf.pkl', '')
            groups_models[group_name] = model
            print(f'INFO: codes classifier for group {group_name} loaded')
except Exception as e:
    # The first failure ends the loop; models registered before it are kept.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
|
| 61 |
|
|
|
|
| 83 |
preds[clf_name] = clf_preds
|
| 84 |
return preds
|
| 85 |
|
| 86 |
+
def classify_code_by_group(text, group_name, top_n):
    """Rank the codes of one group for a piece of text.

    The text is passed through the module-level ``embedder`` and scored
    with the classifier stored under ``group_name`` in ``groups_models``.

    Args:
        text: Text to classify.
        group_name: Key into ``groups_models``; an unknown key raises
            ``KeyError``.
        top_n: How many of the highest-scoring classes to report.

    Returns:
        A tuple ``(top_cls, top_n_preds, all_codes_in_group)``:
        the best-scoring entry of ``model.classes_``, a dict mapping
        ``str(class)`` to its score as ``float`` ordered best first, and
        ``model.classes_`` itself.
    """
    features = [embedder(text)]
    clf = groups_models[group_name]
    scores = clf.predict_proba(features)

    # argsort is ascending, so the last top_n indices of the single row are
    # the best ones; reverse them to put the highest score first.
    ranked = np.argsort(scores, axis=1)[0, -top_n:][::-1]

    top_n_preds = {}
    for idx in ranked:
        top_n_preds[str(clf.classes_[idx])] = float(scores[0][idx])

    return clf.classes_[ranked[0]], top_n_preds, clf.classes_
|
| 96 |
+
|
| 97 |
def get_top_result(preds):
|
| 98 |
total_scores = {}
|
| 99 |
for clf_name, scores in preds.items():
|
|
|
|
| 123 |
return {'response': data}
|
| 124 |
|
| 125 |
@app.route("/predict", methods=['POST'])
|
| 126 |
+
def predict_api():
|
| 127 |
data = request.json
|
| 128 |
base64_bytes = str(data['textB64']).encode("ascii")
|
| 129 |
sample_string_bytes = base64.b64decode(base64_bytes)
|
|
|
|
| 147 |
}
|
| 148 |
return result
|
| 149 |
|
| 150 |
+
@app.route("/predict_code", methods=['POST'])
def predict_code_api():
    """Predict codes inside one group for base64-encoded text.

    Reads a JSON body with these fields:
        textB64: base64-encoded text to classify.
        top_n: number of best codes to return; must be at least 1.
        dx_group: name of the group whose classifier is used; must be a
            key of ``groups_models``.

    Returns:
        ``{'error': <message>}`` when one of the checks below fails,
        otherwise ``{'icd10': {'result', 'details', 'all_codes'}}`` where
        every value is a plain ``str``/``float``/``list``/``dict`` so the
        response can be JSON-encoded.

    Raises:
        KeyError: if a required field is missing from the body.
        ValueError: if ``top_n`` is not an integer or the payload is not
            valid base64 / UTF-8.
    """
    data = request.json
    base64_bytes = str(data['textB64']).encode("ascii")
    sample_string_bytes = base64.b64decode(base64_bytes)
    # UTF-8 is a superset of ASCII: ASCII payloads decode exactly as before,
    # while non-ASCII text no longer fails with UnicodeDecodeError.
    text = sample_string_bytes.decode("utf-8")
    top_n = int(data['top_n'])
    group_name = data['dx_group']

    if top_n < 1:
        return {'error': 'top_n should be geather than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}
    if group_name not in groups_models:
        return {'error': 'have no classifier for the group'}

    top_pred_code, pred_codes, all_codes_in_group = classify_code_by_group(text, group_name, top_n)

    # The classifier hands back NumPy objects (a class array and one of its
    # elements), which the JSON encoder cannot serialize. Convert them to
    # str, matching how the keys of 'details' are already represented.
    result = {
        "icd10": {
            'result': str(top_pred_code),
            'details': pred_codes,
            'all_codes': [str(code) for code in all_codes_in_group],
        }
    }
    return result
|
| 172 |
+
|
| 173 |
# Script entry point: start the Flask app when this file is run directly.
if __name__ == "__main__":
    # Listen on every interface, port 7860.
    app.run(
        host='0.0.0.0',
        port=7860,
    )
|