File size: 2,748 Bytes
60a68a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f92fe7
 
 
60a68a6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from fastai.vision.all import load_learner, PILImage
import torch
import csv
import hashlib
import json
from pathlib import Path
import os


def get_preds(obj, learn, model_name='tags', thresh=15):
    """Build a thresholded prediction list annotated with synonym mappings.

    Args:
        obj: Iterable of per-class scores. Each element must expose
            ``.item()`` returning a number; the score is multiplied by 100
            and rounded to one decimal. Presumably the probability tensor
            returned by ``Learner.predict``, ordered like
            ``learn.dls.vocab`` -- confirm against the caller.
        learn: Object whose ``dls.vocab`` yields the class labels.
        model_name: ``'life-event'`` selects ``mapping-life-event.csv``;
            any other value selects ``mapping.csv``.
        thresh: A class is reported only when its percentage is strictly
            greater than this value.

    Returns:
        dict: ``{"predictions": [...]}`` where each entry has the keys
        ``"label"``, ``"probability"`` (percentage) and ``"synonyms"``
        (list of alternatives from the mapping file, or ``[]``).

    Raises:
        OSError: If the mapping CSV cannot be opened.
        KeyError: If the CSV lacks a ``tag`` or ``alternatives`` column.
    """
    # Class labels, in the order the model emits its scores.
    labels = list(learn.dls.vocab)

    # Pick the synonym mapping that belongs to the requested model.
    if model_name == 'life-event':
        input_file = "./model/cardtagger/mapping-life-event.csv"
    else:
        input_file = "./model/cardtagger/mapping.csv"

    # Load tag -> alternatives. Rows whose alternatives column merely repeats
    # the tag carry no extra information and are skipped.
    # The context manager guarantees the file handle is closed; newline=''
    # is what the csv module expects for files it reads.
    synonyms_by_tag = {}
    with open(input_file, newline='') as mapping_file:
        for row in csv.DictReader(mapping_file):
            if row['tag'] != row['alternatives']:
                synonyms_by_tag[row['tag']] = row['alternatives'].split(',')

    # Keep only the classes scoring above the threshold and attach synonyms.
    predictions = []
    for idx, item in enumerate(obj):
        acc = round(item.item() * 100, 1)

        if acc > thresh:
            label = labels[idx]
            predictions.append({
                "label": label,
                "probability": acc,
                "synonyms": synonyms_by_tag.get(label, []),
            })

    return {"predictions": predictions}


def cardtagger(image):
    """Classify a card image with the tag model and the life-event model.

    The image is resized to 128x128 and the MD5 digest of the resized pixel
    bytes is used as a cache key under ``./tmp/``. If a file with that name
    exists, its JSON content is returned. Otherwise both models are loaded
    and run and their predictions are concatenated.

    Args:
        image: Anything ``PILImage.create`` accepts (e.g. a path such as
            ``'test.jpg'``).

    Returns:
        dict: ``{"predictions": [...]}`` -- tag predictions (default
        threshold of ``get_preds``) followed by life-event predictions
        (threshold 30).
    """
    img = PILImage(PILImage.create(image).resize((128,128)))

    # Cache lookup: results are keyed by the hash of the resized image bytes.
    base = Path("./tmp/")
    md5hash = hashlib.md5(img.tobytes()).hexdigest()
    file = os.path.join(base, md5hash)

    if os.path.exists(file):
        # Context manager so the cache file handle is always closed.
        with open(file) as cached:
            return json.load(cached)

    # Tag classification.
    tag_model = load_learner('./model/cardtagger/tags.pkl')
    _, _, tag_probs = tag_model.predict(img)
    result_tags = get_preds(tag_probs, tag_model, 'tags')

    # Life-event classification (stricter threshold of 30).
    life_event_model = load_learner('./model/cardtagger/life-event-2.pkl')
    _, _, life_event_probs = life_event_model.predict(img)
    result_life = get_preds(life_event_probs, life_event_model, 'life-event', 30)

    # Combine both prediction lists into a single response.
    result = {"predictions": result_tags['predictions'] + result_life['predictions']}

    # NOTE(review): writing the result back to the cache is currently
    # disabled, so this function never creates the files it looks up above.
    # out_file = open(file, "w+")
    #
    # json.dump(result, out_file)

    return result

#cardtagger('test.jpg')