Commit 05ff23c by Juan Pablo Balarini (parent: a6bec27): Update README with usage instructions

README.md (changed)
## How to Get Started with the Model

Load the model and define an `encode_message` helper function:
```python
import torch
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
from qwen_vl_utils import process_vision_info

MODEL_NAME = "eagerworks/eager-embed-v1"

# Pick the best available device and dtype
DEVICE = torch.device("cpu")
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
DTYPE = torch.bfloat16

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    attn_implementation=(
        "flash_attention_2" if is_flash_attn_2_available() else None
    ),
    dtype=DTYPE,
).to(DEVICE).eval()

# Encode a chat-formatted message (text and/or images) into a single
# L2-normalized embedding taken from the last token's hidden state.
def encode_message(message):
    with torch.no_grad():
        texts = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
        image_inputs, video_inputs = process_vision_info(message)

        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            padding="longest",
        ).to(DEVICE)

        model_outputs = model(**inputs, return_dict=True, output_hidden_states=True)

        last_hidden_state = model_outputs.hidden_states[-1]
        embeddings = last_hidden_state[:, -1]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return embeddings
```
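
`encode_message` pools the hidden state of the last token and L2-normalizes it, so cosine similarity between two outputs reduces to a plain dot product. A quick sanity check (a minimal sketch; the example string is illustrative):
```python
# Inspect the shape and norm of one embedding produced by the helper above.
msg = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Query: hello world'}]}]
emb = encode_message(msg)
print(emb.shape)                              # (1, hidden_size) of the Qwen3-VL backbone
print(torch.linalg.vector_norm(emb, dim=-1))  # ~1.0, since embeddings are L2-normalized
```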

🌍 Multilingual Text Retrieval
```python
example_query = "Query: What is the capital city of Uruguay?"
example_text_1 = "Montevideo es la capital y la ciudad más poblada de la República Oriental del Uruguay, así como la capital del departamento homónimo"
example_text_2 = "El río Uruguay es un río internacional que forma parte de la cuenca del Plata. Nace en Brasil, recorre unos 1.800 km y desemboca en el Río de la Plata"
query = [{'role': 'user', 'content': [{'type': 'text', 'text': example_query}]}]
text_1 = [{'role': 'user', 'content': [{'type': 'text', 'text': example_text_1}]}]
text_2 = [{'role': 'user', 'content': [{'type': 'text', 'text': example_text_2}]}]

sim1 = torch.cosine_similarity(encode_message(query), encode_message(text_1))
sim2 = torch.cosine_similarity(encode_message(query), encode_message(text_2))

print("Similarities:", sim1.item(), sim2.item())
```

📈 Image Document Retrieval (Image, Chart, PDF)
```python
MAX_IMAGE_SIZE = 784
example_query = 'Query: Where can we find the animal llama?'
example_image_1 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/animal-llama.png"
example_image_2 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/meta-llama.png"
query = [{'role': 'user', 'content': [{'type': 'text', 'text': example_query}]}]
image_1 = [{'role': 'user', 'content': [{'type': 'image', 'image': example_image_1, 'resized_height': MAX_IMAGE_SIZE, 'resized_width': MAX_IMAGE_SIZE}]}]
image_2 = [{'role': 'user', 'content': [{'type': 'image', 'image': example_image_2, 'resized_height': MAX_IMAGE_SIZE, 'resized_width': MAX_IMAGE_SIZE}]}]

sim1 = torch.cosine_similarity(encode_message(query), encode_message(image_1))
sim2 = torch.cosine_similarity(encode_message(query), encode_message(image_2))

print("Similarities:", sim1.item(), sim2.item())
```
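
Beyond pairwise comparisons, the same helper can back a small in-memory retrieval index: encode each candidate once, stack the embeddings, and rank by dot product, which equals cosine similarity here because `encode_message` returns unit-length vectors. A minimal sketch, assuming the model and helper above are already loaded; `corpus_texts` and its passages are illustrative:
```python
# Encode a small corpus once, then score every passage against a query in one matmul.
corpus_texts = [
    "Montevideo is the capital and most populous city of Uruguay.",
    "The Uruguay River rises in Brazil and flows into the Río de la Plata.",
    "The llama is a domesticated South American camelid.",
]
corpus = [[{'role': 'user', 'content': [{'type': 'text', 'text': t}]}] for t in corpus_texts]
corpus_query = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Query: What is the capital city of Uruguay?'}]}]

doc_embeddings = torch.cat([encode_message(m) for m in corpus], dim=0)  # (N, hidden_size)
query_embedding = encode_message(corpus_query)                          # (1, hidden_size)

# Unit-length embeddings: dot product == cosine similarity.
scores = (query_embedding @ doc_embeddings.T).squeeze(0)                # (N,)
top_values, top_indices = scores.topk(k=min(3, len(corpus_texts)))
for score, idx in zip(top_values.tolist(), top_indices.tolist()):
    print(f"{score:.3f}  {corpus_texts[idx]}")
```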

## Training Details