not-lain commited on
Commit
3051d55
·
verified ·
1 Parent(s): 5a0af5a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +130 -2
README.md CHANGED
@@ -7,5 +7,133 @@ tags:
7
  ---
8
 
9
  This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
10
- - Library: https://github.com/Arabic-Clip/Araclip_Enhanced
11
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  ---
8
 
9
  This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
10
+ - Library: coming soon ✨
11
+
12
+ ## How to use
13
+ ```
14
+ pip install transformers open_clip_torch timm "huggingface_hub>=0.29.0"
15
+ ```
16
+
17
+ ```python
18
+ import numpy as np
19
+ from PIL import Image
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torchvision.transforms as transforms
24
+
25
+ from huggingface_hub import PyTorchModelHubMixin
26
+ from transformers import BertConfig, BertModel, AutoTokenizer
27
+ from open_clip import (
28
+ create_model,
29
+ )
30
+
31
+
32
class MultilingualClipEdited(nn.Module, PyTorchModelHubMixin):
    """Text tower: a BERT encoder followed by a linear projection into CLIP space."""

    def __init__(
        self, transformer_cfg, in_features, out_features, tokenizer_repo_id_or_path
    ):
        super().__init__()
        # Encoder is built from a config (weights arrive via the Hub mixin).
        self.transformer = BertModel(BertConfig(**transformer_cfg))
        # Linear head maps BERT hidden size -> CLIP embedding size.
        self.clip_head = nn.Linear(in_features=in_features, out_features=out_features)
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_repo_id_or_path,
        )

    def forward(self, txt):
        """Tokenize *txt*, mean-pool token embeddings (mask-aware), and project."""
        tokens = self.tokenizer(txt, padding=True, return_tensors="pt")
        hidden = self.transformer(**tokens)[0]
        mask = tokens["attention_mask"]
        # Zero out padding positions, then average over the real tokens only.
        masked_sum = (hidden * mask.unsqueeze(2)).sum(dim=1)
        pooled = masked_sum / mask.sum(dim=1)[:, None]
        return self.clip_head(pooled)
49
+
50
+
51
class AraClip(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="araclip",
    repo_url="https://github.com/Arabic-Clip/Araclip_Enhanced",
    tags=["clip"],
):
    """Arabic CLIP: an Arabic BERT text encoder paired with a SigLIP image tower.

    Both embedding paths return L2-normalized numpy vectors, so a plain dot
    product between a text and an image embedding is a cosine similarity.
    """

    def __init__(
        self,
        transformer_cfg,
        in_features,
        out_features,
        tokenizer_repo_id_or_path="Arabic-Clip/bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M",
    ):
        super().__init__()
        self.text_model = MultilingualClipEdited(
            transformer_cfg,
            in_features,
            out_features,
            tokenizer_repo_id_or_path,
        )

        # Image tower built architecture-only; weights come from the Hub mixin,
        # not from the open_clip pretrained registry.
        self.clip_model = create_model("ViT-B-16-SigLIP-512", pretrained_hf=False)
        # SigLIP-512 preprocessing: 512x512 bicubic resize, RGB, [-1, 1] scaling.
        self.compose = transforms.Compose(
            [
                transforms.Resize(
                    (512, 512),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
                transforms.Lambda(lambda img: img.convert("RGB")),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ],
        )

    def language_model(self, queries):
        """Encode a list of strings; returns a (batch, dim) numpy array on CPU."""
        return np.asarray(self.text_model(queries).detach().to("cpu"))

    def _embed_text(self, text):
        # L2-normalized embedding for a single string.
        features = self.language_model([text])[0]
        return features / np.linalg.norm(features)

    def _embed_image(self, image):
        # L2-normalized embedding for a single PIL image.
        img_tensor = self.compose(image).unsqueeze(0)
        with torch.no_grad():
            features = self.clip_model.encode_image(img_tensor)
        features = features.squeeze(0).cpu().numpy()
        return features / np.linalg.norm(features)

    def embed(self, text: str = None, image: Image.Image = None):
        """Embed *text*, *image*, or both.

        Args:
            text: a single query string, or None.
            image: a single PIL image, or None.

        Returns:
            One normalized numpy vector when a single input is given, or a
            ``(text_features, image_features)`` tuple when both are.

        Raises:
            ValueError: if neither ``text`` nor ``image`` is provided.
        """
        if text is None and image is None:
            raise ValueError("Please provide either text or image input")

        # Branches previously duplicated the normalization logic verbatim;
        # the shared helpers keep the three paths identical by construction.
        if text is not None and image is not None:
            return self._embed_text(text), self._embed_image(image)
        if text is not None:
            return self._embed_text(text)
        return self._embed_image(image)
116
+ ```
117
+
118
+ ```python
119
# load model
model = AraClip.from_pretrained("Arabic-Clip/araclip")

# data
labels = ["قطة جالسة", "قطة تقفز", "كلب", "حصان"]
image = Image.open("cat.png")

# embed data (the loaded model is bound to `model`; the original snippet
# called an undefined name `araclip`, which raises NameError)
image_features = model.embed(image=image)
text_features = np.stack([model.embed(text=label) for label in labels])

# search for most similar data: dot product of normalized vectors = cosine similarity
similarities = text_features @ image_features
best_match = labels[np.argmax(similarities)]

print(f"The image is most similar to: {best_match}")
# قطة جالسة
136
+ ```
137
+
138
+
139
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6527e89a8808d80ccff88b7a/d5i4ItET9AZN9xgv8ify5.png)