# Hugging Face Space: end-user personalization demo
# (CLIP image embeddings + hyperdimensional-computing classifier)
from datetime import datetime
from glob import glob
import random
from turtle import title  # NOTE(review): unused; imports tkinter and can break headless builds — candidate for removal

import gradio as gr
import numpy as np
import pandas as pd
import torch
from numpy.random import MT19937, RandomState, SeedSequence
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, pipeline
| clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") | |
| clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") | |
| HYPERDIMS = 1024 | |
| VALUE_BITS = 8 | |
| POS_BITS = 9 # CLIP features are 512 dims | |
| val_bins = np.linspace(start=-1., stop=1., num=2**VALUE_BITS) | |
| print(val_bins.shape, val_bins.min(), val_bins.max(), 'val bins') | |
| def extract_features(image): | |
| PIL_image = Image.fromarray(np.uint8(image)).convert('RGB') | |
| inputs = clip_processor(text=["a photo of a cat", "a photo of a dog"], images=PIL_image, return_tensors="pt", padding=True) | |
| outputs = clip_model(**inputs) | |
| # print(outputs.image_embeds.shape) | |
| return outputs.image_embeds | |
| def update_table(img, img_name, df, state, label, exemplars_state, lut_state): | |
| img_embeds = extract_features(img).detach().numpy().squeeze().tolist() | |
| print(img_name, img.shape, len(img_embeds), 'images left:', len(state)) | |
| new_df = pd.DataFrame({'image_name': img_name, 'label': label, 'image_embed': None}, columns=['image_name', 'image_embed', 'label'], index=[0]) | |
| # print(new_df) | |
| new_df.at[0, 'image_embed'] = img_embeds | |
| df = pd.concat([df, new_df]) | |
| filt = df["image_name"] != "" | |
| df = df[filt] | |
| state.pop() | |
| t = state[-10:] | |
| random.shuffle(t) | |
| state = state[:-10] + t | |
| idx = -1 | |
| next_img = state[idx] | |
| preds = predict(extract_features(img).detach().numpy(), exemplars_state, lut_state) | |
| return next_img, next_img, df, state, preds | |
| def update_table_up(img, img_name, df, state, exemplars_state, lut_state): | |
| return update_table(img, img_name, df, state, 1, exemplars_state, lut_state) | |
| def update_table_down(img, img_name, df, state, exemplars_state, lut_state): | |
| return update_table(img, img_name, df, state, 0, exemplars_state, lut_state) | |
| def make_LUT(nvalues, dims, rs): | |
| lut = np.zeros(shape=(nvalues, dims)) | |
| lut[0, :] = rs.binomial(n=1, p=0.5, size=(dims)) | |
| for row in range(1, nvalues): | |
| lut[row, :] = lut[row-1, :] | |
| # flip few randomly | |
| rand_idx = rs.choice(dims, size=dims//nvalues, replace=False) | |
| lut[row, rand_idx] = 1 - lut[row, rand_idx] | |
| assert np.abs(lut[row, :] - lut[row-1, :]).sum() ==dims//nvalues | |
| unique_rows = np.unique(lut, axis=0) | |
| assert len(unique_rows) == len(lut) | |
| return lut | |
| def load_fn(images, rng_state, exemplars_state, lut_state): | |
| rs = RandomState(MT19937(SeedSequence(123456789))) | |
| rng_state[0] = rs | |
| exemplars_state[0] = rs.binomial(n=1, p=0.5, size=HYPERDIMS) | |
| exemplars_state[1] = rs.binomial(n=1, p=0.5, size=HYPERDIMS) | |
| lut_state[0] = make_LUT(2**VALUE_BITS, HYPERDIMS, rs) | |
| assert lut_state[0].shape[0] == val_bins.shape[0] | |
| lut_state[1] = rs.binomial(n=1, p=0.5, size=(2**POS_BITS, HYPERDIMS)) | |
| print(exemplars_state) | |
| print(lut_state[0].shape, lut_state[1].shape) | |
| return images[-1], images[-1], rng_state, exemplars_state, lut_state | |
| def quantize_embeds(embeds): | |
| assert np.all(embeds >= val_bins[0]) | |
| assert np.all(embeds <= val_bins[-1]) | |
| embeds_flat = embeds.flatten() | |
| all_pairs_dist = np.abs(embeds_flat[:, np.newaxis] - val_bins[np.newaxis, :]) | |
| closest_bin = np.argmin(all_pairs_dist, axis=-1) | |
| quantized_embeds_flat = val_bins[closest_bin] | |
| quantized_embeds = np.reshape(quantized_embeds_flat, embeds.shape) | |
| closest_bin = np.reshape(closest_bin, embeds.shape) | |
| print(closest_bin.shape, 'values are in bins', closest_bin.min(), 'to', closest_bin.max()) | |
| print('abs quant error avg', np.abs(embeds - quantized_embeds).mean()) | |
| return quantized_embeds, closest_bin | |
| def update_exemplars(df, rng, exemplars, lut): | |
| embeds = np.array(df['image_embed'].values.tolist()) # df[['image_embed']].to_numpy() | |
| labels = np.array(df['label'].values.tolist(), 'int') | |
| # print(labels, labels.shape) | |
| assert np.all(np.unique(labels) == [0, 1]) | |
| labels_zero_idx = (labels == 0).nonzero()[0] | |
| labels_one_idx = (labels == 1).nonzero()[0] | |
| print(labels_zero_idx.shape, " zeros and ", labels_one_idx.shape, " ones") | |
| # 70-30 split | |
| labels_zero_train_idx = rng[0].choice(labels_zero_idx, size=int(.7 * len(labels_zero_idx)), replace=False) | |
| labels_one_train_idx = rng[0].choice(labels_one_idx, size=int(.7 * len(labels_one_idx)), replace=False) | |
| embeds_train = np.concatenate([embeds[labels_zero_train_idx], embeds[labels_one_train_idx]], axis=0) | |
| labels_train = np.concatenate([labels[labels_zero_train_idx], labels[labels_one_train_idx]], axis=0) | |
| print('Training set ', embeds_train.shape, labels_train.shape) | |
| print(np.sum(labels_train == 0), " zeros and ", np.sum(labels_train == 1).sum(), " ones") | |
| labels_zero_test_idx = np.setdiff1d(labels_zero_idx, labels_zero_train_idx) | |
| labels_one_test_idx = np.setdiff1d(labels_one_idx, labels_one_train_idx) | |
| embeds_test = np.concatenate([embeds[labels_zero_test_idx], embeds[labels_one_test_idx]], axis=0) | |
| labels_test = np.concatenate([labels[labels_zero_test_idx], labels[labels_one_test_idx]], axis=0) | |
| print('Test set ', embeds_test.shape, labels_test.shape) | |
| quantized_embeds, closest_bin = quantize_embeds(embeds_train) | |
| # closest bin is nexample X 512 | |
| # lut[0] is nvals X dims | |
| # hd_embeds in nexample x 512 x dims | |
| hd_embeds_per_pos = lut[0][closest_bin] | |
| # bundle along pos dimension 512 | |
| # lut[1] is 512 x dims | |
| xor = lambda a,b: a*(1.-b) + b*(1.-a) | |
| hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos) | |
| hd_embeds = np.sum(hd_embeds, axis=1) / embeds_train.shape[-1] | |
| hd_embeds[hd_embeds >= 0.5] = 1. | |
| hd_embeds[hd_embeds < 0.5] = 0. | |
| # hd_embeds_integer is nexample x dims | |
| exemplars_integer = [None, None] | |
| exemplars_integer[0] = np.sum(hd_embeds[labels_train == 0], axis=0) | |
| exemplars_integer[1] = np.sum(hd_embeds[labels_train == 1], axis=0) | |
| exemplars[0] = exemplars_integer[0] / np.sum(labels_train == 0) | |
| exemplars[1] = exemplars_integer[1] / np.sum(labels_train == 1) | |
| exemplars[0][exemplars[0] >= 0.5] = 1. | |
| exemplars[0][exemplars[0] < 0.5] = 0. | |
| exemplars[1][exemplars[1] >= 0.5] = 1. | |
| exemplars[1][exemplars[1] < 0.5] = 0. | |
| print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum()) | |
| preds = np.zeros(hd_embeds.shape[0]) | |
| dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1) | |
| dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1) | |
| preds[dist_to_ex1 < dist_to_ex0] = 1 | |
| print(preds.shape, labels_train.shape, np.sum(preds == labels_train)) | |
| train_acc = np.sum(preds == labels_train) / len(labels_train) | |
| rng, test_acc = score(embeds_test, labels_test, rng, exemplars, lut) | |
| return rng, exemplars, train_acc, test_acc | |
| def score(embeds, labels, rng, exemplars, lut): | |
| quantized_embeds, closest_bin = quantize_embeds(embeds) | |
| # closest bin is nexample X 512 | |
| # lut[0] is nvals X dims | |
| # hd_embeds in nexample x 512 x dims | |
| hd_embeds_per_pos = lut[0][closest_bin] | |
| # bundle along pos dimension 512 | |
| # lut[1] is 512 x dims | |
| xor = lambda a,b: a*(1.-b) + b*(1.-a) | |
| hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos) | |
| hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1] | |
| hd_embeds[hd_embeds >= 0.5] = 1. | |
| hd_embeds[hd_embeds < 0.5] = 0. | |
| # hd_embeds_integer is nexample x dims | |
| print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum()) | |
| preds = np.zeros(hd_embeds.shape[0]) | |
| dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1) | |
| dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1) | |
| preds[dist_to_ex1 < dist_to_ex0] = 1 | |
| print(preds.shape, labels.shape, np.sum(preds == labels), len(labels)) | |
| acc = np.sum(preds == labels) / len(labels) | |
| return rng, acc | |
| def predict(embeds, exemplars, lut): | |
| quantized_embeds, closest_bin = quantize_embeds(embeds) | |
| # closest bin is nexample X 512 | |
| # lut[0] is nvals X dims | |
| # hd_embeds in nexample x 512 x dims | |
| hd_embeds_per_pos = lut[0][closest_bin] | |
| # bundle along pos dimension 512 | |
| # lut[1] is 512 x dims | |
| xor = lambda a,b: a*(1.-b) + b*(1.-a) | |
| hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos) | |
| hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1] | |
| hd_embeds[hd_embeds >= 0.5] = 1. | |
| hd_embeds[hd_embeds < 0.5] = 0. | |
| # hd_embeds_integer is nexample x dims | |
| # print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum()) | |
| dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1) | |
| dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1) | |
| print('dists', dist_to_ex0, dist_to_ex1) | |
| odds = np.abs(dist_to_ex0 - dist_to_ex1).item() | |
| if dist_to_ex1 < dist_to_ex0: | |
| preds = np.array([1., odds]) | |
| else: | |
| preds = np.array([odds, 1.]) | |
| print(preds) | |
| # preds = np.array([-1. * dist_to_ex0, -1. * dist_to_ex1]) | |
| preds = preds / preds.sum() | |
| # print(preds.shape) | |
| print(preds) | |
| return {"👍": preds[1], "👎": preds[0]} | |
| with gr.Blocks(title="End-User Personalization") as demo: | |
| img_list = glob('images/**/*.jpg') | |
| random.seed(datetime.now().timestamp()) | |
| random.shuffle(img_list) | |
| images = gr.State(img_list) | |
| # start_button = gr.Button(label="Start") | |
| with gr.Row(): | |
| image_display = gr.Image() | |
| with gr.Column(): | |
| image_fname = gr.Textbox() | |
| preds = gr.Label("Prediction") | |
| # text_display = gr.Text() | |
| with gr.Row(): | |
| upvote = gr.Button("👍") | |
| downvote = gr.Button("👎") | |
| personalize = gr.Button("Personalize") | |
| with gr.Row(): | |
| train_acc = gr.Textbox(label="Train accuracy") | |
| test_acc = gr.Textbox(label="Test accuracy") | |
| annotated_samples = gr.Dataframe(headers=['image_name', 'label', 'image_embed'], row_count=(1, 'dynamic'), | |
| col_count=(3, 'fixed'), label='Annotations', wrap=False) | |
| # HD stuff for incremental updates | |
| rng = gr.State([None]) | |
| exemplars_state = gr.State([None, None]) | |
| exemplars_state_integer = gr.State([None, None]) | |
| lut_state = gr.State([None, None]) | |
| upvote.click(update_table_up, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds]) | |
| downvote.click(update_table_down, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds]) | |
| personalize.click(update_exemplars, [annotated_samples, rng, exemplars_state, lut_state], [rng, exemplars_state, train_acc, test_acc]) | |
| demo.load(load_fn, inputs=[images, rng, exemplars_state, lut_state], outputs=[image_display, image_fname, rng, exemplars_state, lut_state]) | |
| demo.launch(show_error=True, debug=True) |