### Model Summary
This checkpoint corresponds to the pixel-linguist-all setting in our paper. It is initialized from our monolingual model and then trained alternately on parallel data (205,000 steps) and AllNLI (2,600 steps), switching back and forth for three rounds. This model is the checkpoint from the final round. We recommend running it on an A100 GPU to match the training setup.
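To make the alternating schedule concrete, the sketch below lists the training phases in order. It assumes that each of the three rounds runs one parallel-data phase of 205,000 steps followed by one AllNLI phase of 2,600 steps. The helper `alternating_schedule` is hypothetical and is not part of our codebase, and the paper remains the authoritative description.

```python
from typing import List, Tuple


# Hypothetical helper, not part of the released code.
# Assumption: every round is one parallel-data phase followed by one AllNLI
# phase, using the step counts given in the model summary.
def alternating_schedule(
    rounds: int = 3,
    parallel_steps: int = 205_000,
    nli_steps: int = 2_600,
) -> List[Tuple[int, str, int]]:
    """Return (round, dataset, steps) tuples in training order."""
    phases = []
    for r in range(1, rounds + 1):
        phases.append((r, "parallel", parallel_steps))
        phases.append((r, "AllNLI", nli_steps))
    return phases


if __name__ == "__main__":
    schedule = alternating_schedule()
    for round_idx, dataset, steps in schedule:
        print(f"round {round_idx}: {dataset:<8} {steps:>7} steps")
    print("total steps:", sum(steps for _, _, steps in schedule))
```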
### Downstream Use
Semantic Textual Similarity, Information Retrieval, Reasoning Retrieval
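For the retrieval use cases, you can rank a corpus against a query by embedding similarity. The sketch below shows an assumed workflow, not our evaluation code. It takes L2-normalized embeddings as NumPy arrays, for example the tensors returned by `encode` converted with `.numpy()`, so the dot product equals cosine similarity. The function name `top_k` and the toy vectors are hypothetical.

```python
import numpy as np


def top_k(query_emb: np.ndarray, doc_embs: np.ndarray, k: int = 3):
    """Rank documents by dot-product similarity to a single query.

    query_emb: shape (dim,); doc_embs: shape (num_docs, dim).
    Both are assumed to be L2-normalized, so the dot product is cosine similarity.
    Returns (doc_indices, scores), sorted from most to least similar.
    """
    scores = doc_embs @ query_emb        # one score per document
    k = min(k, scores.shape[0])
    order = np.argsort(-scores)[:k]      # indices of the k highest scores
    return order, scores[order]


if __name__ == "__main__":
    # Toy unit vectors standing in for real model embeddings.
    docs = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])
    query = np.array([0.8, 0.6])
    idx, sc = top_k(query, docs, k=2)
    print(idx, sc)
```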
### Out-of-Scope Use
The model may not be well suited to further fine-tuning for other tasks (such as classification), because it was trained for representation tasks based on similarity matching.
### Training Data
Please refer to the paper for the exact data and training process.
## Inference
Encoding with our PixelLinguist class is straightforward and works much like a SentenceTransformer model: pass a list of strings to `encode` to get one embedding per string.
```python
# Requires the PixelLinguist class defined below.
model_name = "AnonymousPage/checkpoint-all"
model = PixelLinguist(model_name)
texts = ["I love you", "I like you"]
embeddings = model.encode(texts)
# A plain dot product is enough because the embeddings are normalized automatically in the model class.
print(embeddings[0] @ embeddings[1])
# tensor(0.9217)
```
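The comment in the snippet relies on a general fact: for unit-length vectors, the dot product equals cosine similarity. The NumPy sketch below checks this on arbitrary toy vectors, not model outputs. If you ever have embeddings that are not normalized, the `cosine` helper divides by the norms and gives the same score. Both helper names are illustrative.

```python
import numpy as np


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two 1-D vectors of any length."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def l2_normalize(v: np.ndarray) -> np.ndarray:
    """Scale a 1-D vector to unit L2 norm."""
    return v / np.linalg.norm(v)


if __name__ == "__main__":
    a = np.array([3.0, 4.0, 0.0])
    b = np.array([1.0, 2.0, 2.0])
    # After normalization, a plain dot product gives the cosine similarity (11/15).
    print(cosine(a, b), float(l2_normalize(a) @ l2_normalize(b)))
```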
To use the PixelLinguist class, first install the `pixel` package by following the instructions in our GitHub repo. Then define the PixelLinguist class as follows:
```python
import torch
from PIL import Image
from pixel import (
    AutoConfig,
    PangoCairoTextRenderer,
    PIXELForSequenceClassification,
    PIXELForRepresentation,
    PoolingMode,
    get_attention_mask,
    get_transforms,
    glue_strip_spaces,
    resize_model_embeddings,
)
from tqdm import tqdm


class PixelLinguist:
    def __init__(self, model_name, batch_size=16, max_seq_length=64,
                 device=None, pooling="mean", keep_mlp=False):
        # Use the requested device; otherwise the first GPU if one is available.
        if device is not None:
            self.device = device
        else:
            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.config = AutoConfig.from_pretrained(model_name, num_labels=0)
        self.batch_size = batch_size
        # keep_mlp selects PIXELForSequenceClassification (retaining the MLP head);
        # the default uses PIXELForRepresentation.
        model_cls = PIXELForSequenceClassification if keep_mlp else PIXELForRepresentation
        self.model = model_cls.from_pretrained(
            model_name,
            config=self.config,
            pooling_mode=PoolingMode.from_string(pooling),
            add_layer_norm=True,
        ).to(self.device)
        # Text renderer producing grayscale (rgb=False) images split into patches.
        self.processor = PangoCairoTextRenderer.from_pretrained(model_name, rgb=False)
        self.processor.max_seq_length = max_seq_length
        # Resize the model's embeddings to the chosen maximum sequence length.
        resize_model_embeddings(self.model, self.processor.max_seq_length)
        self.transforms = get_transforms(
            do_resize=True,
            size=(self.processor.pixels_per_patch,
                  self.processor.pixels_per_patch * self.processor.max_seq_length),
        )

    def preprocess(self, texts):
        # Render each text to an image, then build its patch-level attention mask.
        encodings = [self.processor(text=glue_strip_spaces(a)) for a in texts]
        pixel_values = torch.stack([self.transforms(Image.fromarray(e.pixel_values)) for e in encodings])
        attention_mask = torch.stack([
            get_attention_mask(e.num_text_patches, seq_length=self.processor.max_seq_length)
            for e in encodings
        ])
        return {"pixel_values": pixel_values, "attention_mask": attention_mask}

    def encode(self, texts, **kwargs):
        # **kwargs is accepted for interface compatibility and is unused here.
        all_outputs = []
        for i in tqdm(range(0, len(texts), self.batch_size)):
            batch_texts = texts[i:i + self.batch_size]
            inputs = self.preprocess(batch_texts)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs).logits.detach().cpu()
            all_outputs.append(outputs)
        # Stack per-batch results into a single (num_texts, dim) tensor.
        return torch.cat(all_outputs, dim=0)
```
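The class defaults to `pooling="mean"`, which is passed to the `pixel` library as a `PoolingMode`. For intuition, the sketch below shows generic masked mean pooling. It averages patch embeddings and counts only the positions whose attention-mask entry is 1, which are roughly the rendered text patches rather than blank padding. This is an illustrative NumPy re-implementation under that assumption, not the library's code, and the function name `masked_mean_pool` is hypothetical.

```python
import numpy as np


def masked_mean_pool(hidden: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average patch embeddings over unmasked positions.

    hidden: (batch, seq_len, dim) patch embeddings.
    mask:   (batch, seq_len), 1 for patches to keep and 0 for padding.
    Returns (batch, dim) pooled embeddings.
    """
    mask = mask.astype(hidden.dtype)[..., None]      # (batch, seq_len, 1)
    summed = (hidden * mask).sum(axis=1)             # sum over kept patches only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)   # avoid division by zero
    return summed / counts


if __name__ == "__main__":
    # One sequence of 3 patches; the last one is padding and is ignored.
    hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
    mask = np.array([[1, 1, 0]])
    print(masked_mean_pool(hidden, mask))  # [[2. 3.]]
```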
### Evaluation
For STS evaluation (see the GitHub repo):
```bash
python tools/evaluation_sts_all.py
```
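STS benchmarks are commonly scored by the Spearman rank correlation between predicted similarities and human gold scores. We do not reproduce `tools/evaluation_sts_all.py` here. The sketch below is a generic, dependency-free Spearman implementation that gives tied values their average rank. It illustrates the kind of metric typically involved and is not a description of that script. The toy scores are made up.

```python
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of 1-based positions i..j
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation of the rank vectors (undefined if an input is constant)."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    vx = sum((a - mx) ** 2 for a in rx)
    vy = sum((b - my) ** 2 for b in ry)
    return cov / (vx * vy) ** 0.5


if __name__ == "__main__":
    predicted = [0.92, 0.35, 0.60, 0.10]  # e.g. dot products of embedding pairs
    gold = [4.8, 1.5, 3.0, 0.2]           # toy human similarity ratings
    print(spearman(predicted, gold))      # 1.0 (identical ordering)
```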
For BEIR information retrieval evaluation (see the GitHub repo):
```bash
python tools/evaluation_retrieval.py
```
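BEIR results are usually reported as nDCG@10. As with STS, we do not reproduce `tools/evaluation_retrieval.py` here. The sketch below computes nDCG@k for a single query, assuming graded relevance labels, linear gain, and the common log2 position discount. Library implementations may differ in details such as tie handling. The document IDs and judgments are toy values.

```python
import math
from typing import Dict, List


def ndcg_at_k(ranked_doc_ids: List[str], qrels: Dict[str, int], k: int = 10) -> float:
    """nDCG@k for one query.

    ranked_doc_ids: retrieved document IDs, best first.
    qrels: graded relevance per document ID (missing IDs count as 0).
    """
    def dcg(gains: List[int]) -> float:
        # The item at 0-based position i is discounted by log2(i + 2).
        return sum(g / math.log2(i + 2) for i, g in enumerate(gains))

    gains = [qrels.get(doc_id, 0) for doc_id in ranked_doc_ids[:k]]
    ideal_dcg = dcg(sorted(qrels.values(), reverse=True)[:k])
    return dcg(gains) / ideal_dcg if ideal_dcg > 0 else 0.0


if __name__ == "__main__":
    qrels = {"d1": 2, "d3": 1}                    # toy relevance judgments
    print(ndcg_at_k(["d1", "d3", "d2"], qrels))   # 1.0 (ideal ordering)
    print(ndcg_at_k(["d2", "d3", "d1"], qrels))   # lower: relevant docs ranked late
```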