# CardTagger / cardtagger.py
# Last commit: "Update cardtagger.py" (5f92fe7) by koenverhagen; file size 2.75 kB.
from fastai.vision.all import load_learner, PILImage
import torch
import csv
import hashlib
import json
from pathlib import Path
import os
def get_preds(obj, learn, model_name='tags', thresh=15):
    """Build the prediction payload for every class scoring above ``thresh``.

    Args:
        obj: Iterable of per-class scores. Each element must expose
            ``.item()`` returning a number; the value is multiplied by 100
            and rounded to one decimal (presumably a tensor of
            probabilities in [0, 1] -- confirm against the caller).
        learn: Object whose ``learn.dls.vocab`` yields the class labels,
            in the same order as the scores in ``obj``.
        model_name: ``'life-event'`` selects ``mapping-life-event.csv``;
            any other value selects ``mapping.csv``.
        thresh: Percentage a class must strictly exceed to be reported.

    Returns:
        ``{"predictions": [...]}`` where each entry is a dict with the keys
        ``"label"``, ``"probability"`` (percentage, one decimal) and
        ``"synonyms"`` (list of alternatives from the mapping CSV, or an
        empty list when the label has no mapping).
    """
    # Class names, in the order the model emits its scores.
    labels = list(learn.dls.vocab)

    if model_name == 'life-event':
        input_file = "./model/cardtagger/mapping-life-event.csv"
    else:
        input_file = "./model/cardtagger/mapping.csv"

    # Load the tag -> alternatives mapping, keeping only rows whose
    # alternatives differ from the tag itself. The context manager makes
    # sure the CSV handle is closed again.
    with open(input_file, newline='') as mapping_file:
        synonyms_by_tag = {
            row['tag']: row['alternatives'].split(',')
            for row in csv.DictReader(mapping_file)
            if row['tag'] != row['alternatives']
        }

    # Pair each label with its score and keep those above the threshold.
    predictions = []
    for label, score in zip(labels, obj):
        acc = round(score.item() * 100, 1)
        if acc > thresh:
            predictions.append({
                "label": label,
                "probability": acc,
                "synonyms": synonyms_by_tag.get(label, []),
            })

    return {"predictions": predictions}
def cardtagger(image):
    """Classify a card image and return tag and life-event predictions.

    The image is resized to 128x128; the MD5 of the resized pixel bytes is
    used as a cache key under ``./tmp/``. If a file with that name exists,
    its JSON content is returned. Otherwise both models are loaded and run,
    and their above-threshold predictions are concatenated (tags first).

    Args:
        image: Anything accepted by ``PILImage.create``.

    Returns:
        ``{"predictions": [...]}`` with the entries produced by
        ``get_preds`` for the tag model (default threshold) followed by
        those for the life-event model (threshold 30).
    """
    img = PILImage(PILImage.create(image).resize((128, 128)))

    # Look up a previously stored result for these exact pixels.
    base = Path("./tmp/")
    md5hash = hashlib.md5(img.tobytes()).hexdigest()
    cache_file = base / md5hash
    if cache_file.exists():
        # Context manager so the cache file handle is closed after reading.
        with open(cache_file) as cached:
            return json.load(cached)

    # Tag classification.
    # NOTE(review): load_learner unpickles the model file -- only load
    # trusted .pkl files from this directory.
    tag_model = load_learner('./model/cardtagger/tags.pkl')
    _, _, tag_probs = tag_model.predict(img)
    result_tags = get_preds(tag_probs, tag_model, 'tags')

    # Life-event classification, reported only above 30%.
    life_event_model = load_learner('./model/cardtagger/life-event-2.pkl')
    _, _, life_event_probs = life_event_model.predict(img)
    result_life = get_preds(life_event_probs, life_event_model, 'life-event', 30)

    # Combine both prediction lists into a single payload.
    result = {"predictions": result_tags['predictions'] + result_life['predictions']}

    # NOTE: writing the result back to the cache file is currently
    # disabled, so this function never populates ./tmp/ itself.
    return result
#cardtagger('test.jpg')