|
|
import json
import os
from typing import Dict, List, Any

import numpy as np
import torch
from fastai.learner import Metric, load_learner
from PIL import Image
|
|
|
|
|
|
|
|
class OrdinalRegressionMetric(Metric):
    """Ratio of target spread to absolute class error, accumulated per batch.

    Two running sums are kept:

    * ``total`` -- sum over samples of ``|argmax(pred) - target|``;
    * ``count`` -- sum over batches of ``|max(target) - min(target)|``.

    ``value`` reports ``1 / (total / count)``, so a larger number means the
    predicted class indices sit closer to the targets.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Clear the running sums.

        The fastai ``Metric`` interface uses ``reset`` to prepare for a new
        computation; without it the sums keep growing across evaluation
        passes and the reported value mixes batches from earlier passes.
        """
        self.total = 0
        self.count = 0

    def accumulate(self, learn):
        """Fold the current batch held by ``learn`` into the running sums.

        Args:
            learn: Object exposing ``pred`` (per-class scores, with classes
                along dim 1) and ``y`` (numeric targets comparable to the
                argmax indices).
        """
        predicted = torch.argmax(learn.pred, dim=1)
        targets = learn.y

        # abs(d) is what sqrt(d ** 2) evaluated to, without squaring first.
        abs_error = torch.abs(predicted - targets).sum()
        target_spread = torch.abs(torch.max(targets) - torch.min(targets))

        self.total += abs_error
        self.count += target_spread

    @property
    def value(self):
        """Return ``1 / (total / count)``, or ``0.0`` while ``count`` is zero.

        When ``count`` is non-zero but ``total`` is a zero tensor (every
        prediction matched its target) the tensor division yields ``inf``.
        """
        if self.count == 0:
            return 0.0
        return 1 / (self.total / self.count)
|
|
|
|
|
class PreTrainedPipeline:
    """Image-classification pipeline backed by an exported fastai learner."""

    def __init__(self, path=""):
        """Load the learner and the label map stored under ``path``.

        Args:
            path: Directory containing ``textfile3-2.pk1`` (the exported
                learner) and ``config.json`` (must provide ``id2label``).
        """
        # presumably kept so the custom metric class is available when the
        # learner is unpickled -- TODO confirm
        self.metric = OrdinalRegressionMetric()
        # NOTE(security): load_learner unpickles the file, so ``path`` must
        # only ever point at a trusted model directory.
        self.model = load_learner(os.path.join(path, "textfile3-2.pk1"))
        config_path = os.path.join(path, "config.json")
        with open(config_path, encoding="utf-8") as config_file:
            config = json.load(config_file)
        self.id2label = config["id2label"]

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Score ``inputs`` against every class the learner outputs.

        Args:
            inputs (:obj:`PIL.Image`):
                The raw image as PIL. It is converted to a NumPy array and
                handed to the learner; no other transformation is applied
                here.

        Return:
            A :obj:`list` of dicts shaped like
            ``{"label": "XXX", "score": 0.82}``, one per entry in the
            learner's output. Entries are in class-index order, not sorted
            by score. Labels are looked up in ``id2label`` by the string
            form of the class index.
        """
        _, _, scores = self.model.predict(np.array(inputs))
        # Iterate over whatever the learner returns instead of assuming a
        # fixed number of classes.
        return [
            {"label": str(self.id2label[str(index)]), "score": score}
            for index, score in enumerate(scores.tolist())
        ]
|
|
|