Spaces:
Runtime error
Runtime error
| import torch | |
| import clip | |
| from PIL import Image | |
| import glob | |
| import os | |
| from random import choice | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP backbone together with its matching image preprocessing transform.
model, preprocess = clip.load("ViT-L/14@336px", device=device)

# Pool of candidate puzzle images: every (non-hidden) entry matched by
# images/* under the current working directory.
COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))

# Presumably the CLIP model names accepted by ``clip.load``; not otherwise
# used in the visible code (only "ViT-L/14@336px" is loaded above).
available_models = [
    'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
    'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px',
]
def load_random_image():
    """Return a randomly chosen image from ``COCO`` as an in-memory PIL image.

    The image data is fully loaded and the underlying file is closed before
    returning, so repeated calls do not accumulate open file handles.

    Raises:
        IndexError: if ``COCO`` is empty (no image files were found).
    """
    if not COCO:
        # ``random.choice`` would raise an opaque "Cannot choose from an
        # empty sequence"; keep the exception type but explain the cause.
        raise IndexError("no images found to choose from (COCO is empty)")
    image_path = choice(COCO)
    # ``Image.open`` is lazy and keeps the file open; ``copy()`` forces the
    # pixel data to load, and leaving the ``with`` block closes the file.
    with Image.open(image_path) as img:
        return img.copy()
def next_image():
    """Pick a new random image and make it the current puzzle target.

    Updates the module-level ``image_org`` (the PIL image), ``image`` (its
    preprocessed single-image batch on ``device``) and ``image_features``
    (its CLIP embedding, which ``answer`` scores guesses against).
    """
    global image_org, image, image_features
    image_org = load_random_image()
    # ``load_random_image`` already returns a PIL image, which is what
    # ``preprocess`` expects; an ``Image.fromarray`` round-trip is
    # unnecessary and mangles palette-mode images.
    image = preprocess(image_org).unsqueeze(0).to(device)
    # Re-encode so that subsequent guesses are scored against the new image
    # rather than the previous one.
    with torch.no_grad():
        image_features = model.encode_image(image)
| # def calculate_logits(image, text): | |
| # return model(image, text)[0] | |
def calculate_logits(image_features, text_features, logit_scale=None):
    """Return scaled cosine-similarity logits between image and text features.

    Both feature tensors are L2-normalised along their last dimension, then
    multiplied as ``image_features @ text_features.T`` and scaled.

    Args:
        image_features: image embeddings, one per row (or a single 1-D vector).
        text_features: text embeddings, one per row (or a single 1-D vector).
        logit_scale: multiplier applied to the similarities. Defaults to the
            loaded CLIP model's learned scale, ``model.logit_scale.exp()``.

    Returns:
        Tensor of shape ``(n_images, n_texts)`` for 2-D inputs.
    """
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    if logit_scale is None:
        # CLIP stores the temperature in log space.
        logit_scale = model.logit_scale.exp()
    return logit_scale * image_features @ text_features.t()
# Game state: score of the previous guess and best score so far
# (-1 means "no guess yet"), plus the target score ``answer`` compares against.
last = best = -1
goal = 23

# Pick the first puzzle image, turn it into a batch of one on ``device``,
# and cache its CLIP embedding so guesses can be scored against it.
image_org = load_random_image()
image = preprocess(image_org).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)
def answer(message):
    """Score a text guess against the current image and classify the result.

    Encodes *message* with CLIP and compares it to the module-level
    ``image_features`` via ``calculate_logits``. It then updates the
    module-level ``last`` (score of the previous guess) and ``best``
    (highest score so far).

    Returns:
        ``(logits, is_better)`` where ``logits`` is the scaled similarity as a
        NumPy scalar and ``is_better`` is, in priority order:
            2  -- the score exceeds ``goal``;
            3  -- a new best score (normally the case for the first guess,
                  since ``best`` starts at -1);
           -1  -- there is no previous guess to compare with;
            1  -- better than the previous guess;
            0  -- worse than the previous guess;
           -1  -- same score as the previous guess.
    """
    global last, best
    text = clip.tokenize([message]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)
        # Must run under no_grad too: ``model.logit_scale`` is a trainable
        # parameter, so outside no_grad the result requires grad and
        # ``.numpy()`` raises a RuntimeError.
        logits = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]

    # Check the goal first. Previously it was only reachable on an exact tie
    # with the last score, and the "new best" override then hid it.
    if logits > goal:
        is_better = 2
    elif logits > best:
        is_better = 3
    elif last == -1:
        is_better = -1
    elif logits > last:
        is_better = 1
    elif logits < last:
        is_better = 0
    else:
        is_better = -1

    last = logits
    if logits > best:
        best = logits
    return logits, is_better
def reset_everything():
    """Restart the game with fresh scores and a new random image.

    Resets the module-level ``last``/``best`` scores (-1 = no guess yet) and
    ``goal``. It then picks a new ``image_org`` and refreshes both its
    preprocessed tensor ``image`` and the CLIP embedding ``image_features``
    that guesses are scored against.
    """
    global last, best, goal, image, image_org, image_features
    last = -1
    best = -1
    goal = 23
    image_org = load_random_image()
    image = preprocess(image_org).unsqueeze(0).to(device)
    # Without re-encoding, ``answer`` would keep scoring against the old image.
    with torch.no_grad():
        image_features = model.encode_image(image)