# Hugging Face Spaces page header ("Spaces: Running") captured by the scrape — not code.
| import base64 | |
| from uuid import uuid4 | |
| import gradio as gr | |
| from fastcore.all import * | |
| from fastai.vision.all import * | |
| import numpy as np | |
| import timm | |
def parent_labels(o):
    "Return the list of labels encoded in the item's parent folder name (comma-separated)."
    folder = Path(o).parent.name
    return folder.split(",")
class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """BCE-with-logits loss whose multi-label targets are label-smoothed.

    Each hard 0/1 target is pulled toward 0.5 by `eps` before the base loss
    is applied; the decode threshold passed to the parent is fixed at 0.1.
    """

    def __init__(self, eps: float = 0.1, **kwargs):
        self.eps = eps
        super().__init__(thresh=0.1, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        # smooth: 1 -> 1 - eps/2, 0 -> eps/2
        smoothed = targ.float() * (1.0 - self.eps) + self.eps * 0.5
        return super().__call__(inp, smoothed, **kwargs)
# Load the exported fastai learner (architecture + weights + vocab) from disk.
learn = load_learner('models.pkl')
# set a new loss function with a threshold of 0.4 to remove more false positives
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
def predict_tags(image, vtt, threshold=0.4):
    """Classify every sprite-sheet frame described by `vtt` and return, per tag,
    the highest-probability occurrence.

    image: the sprite sheet; vtt: a base64 data-URI of the WEBVTT file mapping
    frames to xywh offsets; threshold: multi-label decode threshold.
    Returns {tag: {'prob', 'offset', 'frame', 'time'}}.
    """
    raw_vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sheet = PILImage.create(image)

    frame_ids = []
    stamps = []
    crops = []
    boxes = []
    for idx, (x, y, w, h, seconds) in enumerate(getVTToffsets(raw_vtt)):
        frame_ids.append(idx)
        stamps.append(seconds)
        boxes.append((x, y, w, h))
        patch = sheet.crop((x, y, x + w, y + h))
        crops.append(PILImage.create(np.asarray(patch)))

    # Temporarily install a loss with the requested threshold so get_preds
    # decodes activations at it, then restore the 0.4 default afterwards.
    threshold = threshold or 0.4
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    dl = learn.dls.test_dl(crops, bs=64)
    probs, _, decoded = learn.get_preds(dl=dl, with_decoded=True)
    learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)

    # Keep, for each tag, the frame where it scored highest.
    best = {}
    for f_idx, row in enumerate(decoded):
        for t_idx, hit in enumerate(row):
            if not hit:
                continue
            label = learn.dls.vocab[t_idx].replace("_", " ")
            if label not in best:
                best[label] = {'prob': 0, 'offset': (), 'frame': 0}
            p = float(probs[f_idx][t_idx])
            if best[label]['prob'] < p:
                best[label]['prob'] = p
                best[label]['offset'] = boxes[f_idx]
                best[label]['frame'] = f_idx
                best[label]['time'] = stamps[f_idx]
    return best
def predict_markers(image, vtt, threshold=0.4):
    """Classify every sprite-sheet frame and reduce the results to scene markers.

    A tag only survives if it also appears in the immediately following frame,
    and consecutive marker candidates sharing any tag are collapsed into one.
    Returns a list of {'offset', 'frame', 'time', 'tag', 'id'} dicts.
    """
    raw_vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sheet = PILImage.create(image)

    frame_ids = []
    stamps = []
    crops = []
    boxes = []
    for idx, (x, y, w, h, seconds) in enumerate(getVTToffsets(raw_vtt)):
        frame_ids.append(idx)
        stamps.append(seconds)
        boxes.append((x, y, w, h))
        patch = sheet.crop((x, y, x + w, y + h))
        crops.append(PILImage.create(np.asarray(patch)))

    # Temporarily decode at the requested threshold, then restore the default.
    threshold = threshold or 0.4
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    dl = learn.dls.test_dl(crops, bs=64)
    probs, _, decoded = learn.get_preds(dl=dl, with_decoded=True)
    learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)

    # One record per frame that produced at least one tag.
    per_frame = []
    for f_idx, row in enumerate(decoded):
        hits = []
        for t_idx, flag in enumerate(row):
            if not flag:
                continue
            label = learn.dls.vocab[t_idx].replace("_", " ")
            hits.append({'label': label, 'prob': float(probs[f_idx][t_idx])})
        if not hits:
            continue
        per_frame.append({'offset': boxes[f_idx], 'frame': f_idx,
                          'time': stamps[f_idx], 'tags': hits})

    # Keep only tags that persist into the following frame (the final frame
    # has no successor, so it is dropped).
    survivors = []
    for pos in range(len(per_frame) - 1):
        info = per_frame[pos]
        following = per_frame[pos + 1]['tags']
        kept = []
        for t in info['tags']:
            for nxt in following:
                if t['label'] == nxt['label']:
                    kept.append(t)
        info['tags'] = kept
        if kept:
            survivors.append(info)

    # Collapse runs of frames that share a tag into a single marker, keeping
    # the highest-probability tag of the first frame in each run.
    previous_labels = set()
    markers = []
    for info in survivors:
        labels = {t['label'] for t in info['tags']}
        if labels.intersection(previous_labels):
            continue
        previous_labels = labels
        info['tag'] = sorted(info['tags'], key=lambda t: t['prob'], reverse=True)[0]
        del info['tags']
        # add unique id to the frame
        info['id'] = str(uuid4())
        markers.append(info)
    return markers
def getVTToffsets(vtt):
    """Parse a WEBVTT thumbnail file and yield one tuple per cue.

    vtt: raw WEBVTT bytes whose cue payloads carry media-fragment offsets
    like ``sprite.jpg#xywh=0,0,160,90``.

    Yields ``(left, top, right, bottom, time_seconds)`` where — despite the
    names — ``right``/``bottom`` are the fragment's *width*/*height* (callers
    crop with ``(left, top, left + right, top + bottom)``), and
    ``time_seconds`` is the cue's start time.
    """
    time_seconds = 0
    left = top = right = bottom = None
    for line in vtt.decode("utf-8").split("\n"):
        line = line.strip()
        if "-->" in line:
            # grab the start time, e.g. "00:00:00.000 --> 00:00:41.000"
            start = line.split("-->")[0].strip().split(":")
            # convert HH:MM:SS.mmm to seconds
            time_seconds = (
                int(start[0]) * 3600
                + int(start[1]) * 60
                + float(start[2])
            )
            # a new cue begins: forget the previous cue's geometry
            left = top = right = bottom = None
        elif "xywh=" in line:
            left, top, right, bottom = line.split("xywh=")[-1].split(",")
            left, top, right, bottom = (
                int(left),
                int(top),
                int(right),
                int(bottom),
            )
        else:
            continue
        # BUG FIX: the original `if not left` also skipped valid frames whose
        # x offset is 0 (the first column of every sprite sheet). Only skip
        # when no xywh line has been seen for the current cue yet.
        if left is None:
            continue
        yield left, top, right, bottom, time_seconds
# create a gradio interface with 2 tabs
# Tab 1: per-tag best-frame predictions as JSON.
tag = gr.Interface(
    fn=predict_tags,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),  # base64 data-URI of the WEBVTT sprite index
        gr.Number(value=0.4, label="Threshold")
    ],
    outputs=gr.JSON(label=""),
)
# Tab 2: deduplicated scene markers derived from consecutive-frame agreement.
marker = gr.Interface(
    fn=predict_markers,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold")
    ],
    outputs=gr.JSON(label=""),
)
# Bind to all interfaces so the app is reachable from outside the container.
gr.TabbedInterface(
    [tag, marker], ["tag", "marker"]
).launch(server_name="0.0.0.0")