Spaces:
Sleeping
Sleeping
| import torch | |
| import clip | |
| import os | |
| import numpy as np | |
# Prompt templates used for prompt ensembling. zeroshot_classifier fills the
# single "{}" slot of every template with a class name, encodes all resulting
# prompts, and averages their normalised embeddings into one vector per class.
# NOTE(review): presumably derived from CLIP's ImageNet prompt ensemble -- confirm.
# Strings are intentionally kept verbatim (including wordings such as
# 'a origami {}.'): editing any prompt changes the averaged text embedding.
imagenet_templates = [
    'a bad photo of a {}.',
    # 'a photo of many {}.',  (deliberately excluded from the ensemble)
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]
def _infer_device(model):
    """Return the device of ``model``'s first parameter, or CUDA if unknown.

    Falling back to ``cuda`` preserves the historical hard-coded behaviour
    for model objects that do not expose ``parameters()``.
    """
    parameters = getattr(model, "parameters", None)
    if parameters is None:
        return torch.device("cuda")
    try:
        return next(iter(parameters())).device
    except StopIteration:
        # Parameter-less module: keep the old default.
        return torch.device("cuda")


def zeroshot_classifier(classnames, templates, model, device=None):
    """Build one prompt-ensembled text embedding per class name.

    For every class name, each template is filled in with ``str.format``,
    the prompts are tokenized with ``clip.tokenize`` and encoded with
    ``model.encode_text``. Each prompt embedding is L2-normalised, the
    normalised embeddings are averaged, and the mean is L2-normalised again.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings, each with one ``{}`` slot.
            It is materialised once, so one-shot iterables are supported.
        model: CLIP-style model exposing ``encode_text``.
        device: Device for the token tensors and the returned weights.
            Defaults to the device of the model's first parameter, falling
            back to ``"cuda"`` (the previous hard-coded device).

    Returns:
        Tensor built with ``torch.stack(..., dim=1)``, i.e. one column per
        class; assuming ``encode_text`` returns one row per prompt, its shape
        is ``(embed_dim, len(classnames))``.
    """
    if device is None:
        device = _infer_device(model)
    # Re-iterated once per class; a generator would be empty after class one.
    templates = list(templates)
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            texts = [template.format(classname) for template in templates]  # format with class
            tokens = clip.tokenize(texts).to(device)  # tokenize on the model's device
            class_embeddings = model.encode_text(tokens)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # Re-normalise: the mean of unit vectors is generally not unit length.
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights
def GetDt(classnames, model):
    """Return the text-embedding difference between the first two classes.

    Builds prompt-ensembled embeddings for ``classnames`` with
    ``imagenet_templates`` and subtracts the second class's embedding from
    the first's. Class names beyond the second are encoded but not used.

    Args:
        classnames: Sequence whose first two entries are compared
            (presumably target text first, neutral text second -- confirm
            against callers).
        model: CLIP-style model forwarded to ``zeroshot_classifier``.

    Returns:
        NumPy array ``embedding(classnames[0]) - embedding(classnames[1])``,
        copied to CPU memory.
    """
    weights = zeroshot_classifier(classnames, imagenet_templates, model)
    # Each column of ``weights`` is one class's embedding.
    direction = weights[:, 0] - weights[:, 1]
    return direction.cpu().numpy()