Spaces:
Runtime error
Runtime error
| from fastai.vision.all import load_learner, PILImage | |
| import torch | |
| import csv | |
| import hashlib | |
| import json | |
| from pathlib import Path | |
| import os | |
def _load_synonym_mapping(mapping_file):
    """Read a ``tag,alternatives`` CSV into ``{tag: [alternative, ...]}``.

    ``alternatives`` is split on commas.  Rows whose ``alternatives`` value is
    identical to ``tag`` are skipped, so only tags with a real mapping are kept.
    When a tag appears on several rows, the last row wins.
    """
    # newline="" is what the csv module expects; the with-block closes the file.
    with open(mapping_file, newline="", encoding="utf-8") as fh:
        return {
            row["tag"]: row["alternatives"].split(",")
            for row in csv.DictReader(fh)
            if row["tag"] != row["alternatives"]
        }


def get_preds(obj, learn, model_name='tags', thresh=15, mapping_file=None):
    """Turn per-class scores into labelled predictions with synonyms.

    Args:
        obj: Iterable of per-class scores, in the same order as
            ``learn.dls.vocab``.  Each element must expose ``.item()`` and is
            treated as a fraction that is converted to a percentage.
        learn: Object whose ``dls.vocab`` supplies the class names.
        model_name: ``'life-event'`` selects the life-event mapping CSV; any
            other value selects the default tag mapping CSV.
        thresh: Percentage a score must strictly exceed to be reported.
        mapping_file: Optional path to a ``tag,alternatives`` CSV.  When given,
            it overrides the path chosen from ``model_name``.

    Returns:
        ``{"predictions": [...]}``, where each entry is a dict with keys
        ``"label"``, ``"probability"`` (percentage rounded to one decimal) and
        ``"synonyms"`` (``[]`` when the label has no mapping).  Entries follow
        vocab order.

    Raises:
        IndexError: if ``obj`` has more entries above the threshold than
            ``learn.dls.vocab`` has labels.
    """
    labels = list(learn.dls.vocab)

    if mapping_file is None:
        mapping_file = (
            "./model/cardtagger/mapping-life-event.csv"
            if model_name == 'life-event'
            else "./model/cardtagger/mapping.csv"
        )
    synonyms_by_tag = _load_synonym_mapping(mapping_file)

    predictions = []
    for idx, score in enumerate(obj):
        acc = round(score.item() * 100, 1)
        if acc > thresh:
            label = labels[idx]
            predictions.append({
                "label": label,
                "probability": acc,
                "synonyms": synonyms_by_tag.get(label, []),
            })
    return {"predictions": predictions}
# Learners used by cardtagger(), keyed by model path, so each .pkl file is
# unpickled once per process instead of on every request.
_LEARNER_CACHE = {}


def _get_learner(model_path):
    """Return the fastai learner stored at *model_path*, loading it on first use."""
    learner = _LEARNER_CACHE.get(model_path)
    if learner is None:
        # load_learner unpickles the file: only point this at trusted model files.
        learner = load_learner(model_path)
        _LEARNER_CACHE[model_path] = learner
    return learner


def cardtagger(image):
    """Classify *image* with the tag model and the life-event model.

    The image is resized to 128x128 before hashing and prediction.  If
    ``./tmp/<md5 of the resized image bytes>`` exists, its JSON content is
    returned without running either model.

    Args:
        image: Image source passed to ``PILImage.create`` (e.g. a file path).

    Returns:
        ``{"predictions": [...]}``: the tag predictions (threshold 15) followed
        by the life-event predictions (threshold 30), each built by
        ``get_preds``.
    """
    img = PILImage(PILImage.create(image).resize((128, 128)))

    # Look up previously classified images by the md5 of their resized pixel data.
    cache_file = Path("./tmp/") / hashlib.md5(img.tobytes()).hexdigest()
    if cache_file.exists():
        with cache_file.open() as fh:
            return json.load(fh)

    tag_model = _get_learner('./model/cardtagger/tags.pkl')
    _, _, tag_probs = tag_model.predict(img)
    result_tags = get_preds(tag_probs, tag_model, 'tags')

    life_event_model = _get_learner('./model/cardtagger/life-event-2.pkl')
    _, _, life_event_probs = life_event_model.predict(img)
    result_life = get_preds(life_event_probs, life_event_model, 'life-event', 30)

    # NOTE(review): results are not written back to ./tmp (the write-back was
    # disabled), so the lookup above only hits files produced elsewhere.
    return {"predictions": result_tags['predictions'] + result_life['predictions']}
| #cardtagger('test.jpg') | |