Juan Pablo Balarini committed on
Commit
05ff23c
·
1 Parent(s): a6bec27

Update README with usage instructions

Browse files
Files changed (1) hide show
  1. README.md +74 -2
README.md CHANGED
@@ -28,9 +28,81 @@ Compared to multi-vector (ColBERT-like) architectures, eager-embed-v1 offers a s
28
 
29
  ## How to Get Started with the Model
30
 
31
- **TODO: Coming soon**. For now, check [the inference script](https://github.com/eagerworks/eager-embed/blob/main/inference.py).
32
  ```python
33
- TODO: Add inference code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  ```
35
 
36
  ## Training Details
 
28
 
29
  ## How to Get Started with the Model
30
 
31
+ Load the model and define the encode helper function
32
  ```python
33
+ import torch
34
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
35
+ from transformers.utils.import_utils import is_flash_attn_2_available
36
+ from qwen_vl_utils import process_vision_info
37
+
38
+ MODEL_NAME = "eagerworks/eager-embed-v1"
39
+ DEVICE = torch.device("cpu")
40
+ if torch.cuda.is_available():
41
+ DEVICE = torch.device("cuda:0")
42
+ elif torch.backends.mps.is_available():
43
+ DEVICE = torch.device("mps")
44
+ DTYPE = torch.bfloat16
45
+
46
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
47
+ model = Qwen3VLForConditionalGeneration.from_pretrained(
48
+ MODEL_NAME,
49
+ attn_implementation=(
50
+ "flash_attention_2" if is_flash_attn_2_available() else None
51
+ ),
52
+ dtype=DTYPE
53
+ ).to(DEVICE).eval()
54
+
55
+ # Function to Encode Message
56
+ def encode_message(message):
57
+ with torch.no_grad():
58
+ texts = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
59
+ image_inputs, video_inputs = process_vision_info(message)
60
+
61
+ inputs = processor(
62
+ text=texts,
63
+ images=image_inputs,
64
+ videos=video_inputs,
65
+ return_tensors="pt",
66
+ padding="longest",
67
+ ).to(DEVICE)
68
+
69
+ model_outputs = model(**inputs, return_dict=True, output_hidden_states=True)
70
+
71
+ last_hidden_state = model_outputs.hidden_states[-1]
72
+ embeddings = last_hidden_state[:, -1]
73
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
74
+ return embeddings
75
+ ```
76
+
77
+ 🌍 Multilingual Text Retrieval
78
+ ```python
79
+ example_query = "Query: What is the capital city of Uruguay?"
80
+ example_text_1 = "Montevideo es la capital y la ciudad más poblada de la República Oriental del Uruguay, así como la capital del departamento homónimo"
81
+ example_text_2 = "El río Uruguay es un río internacional que forma parte de la cuenca del Plata. Nace en Brasil, recorre unos 1.800 km y desemboca en el Río de la Plata"
82
+ query = [{'role': 'user', 'content': [{'type': 'text', 'text': example_query}]}]
83
+ text_1 = [{'role': 'user', 'content': [{'type': 'text', 'text': example_text_1}]}]
84
+ text_2 = [{'role': 'user', 'content': [{'type': 'text', 'text': example_text_2}]}]
85
+
86
+ sim1 = torch.cosine_similarity(encode_message(query), encode_message(text_1))
87
+ sim2 = torch.cosine_similarity(encode_message(query), encode_message(text_2))
88
+
89
+ print("Similarities:", sim1.item(), sim2.item())
90
+ ```
91
+
92
+ 📈 Image Document Retrieval (Image, Chart, PDF)
93
+ ```python
94
+ MAX_IMAGE_SIZE = 784
95
+ example_query = 'Query: Where can we find the animal llama?'
96
+ example_image_1 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/animal-llama.png"
97
+ example_image_2 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/meta-llama.png"
98
+ query = [{'role': 'user', 'content': [{'type': 'text', 'text': example_query}]}]
99
+ image_1 = [{'role': 'user', 'content': [{'type': 'image', 'image': example_image_1, 'resized_height': MAX_IMAGE_SIZE, 'resized_width': MAX_IMAGE_SIZE}]}]
100
+ image_2 = [{'role': 'user', 'content': [{'type': 'image', 'image': example_image_2, 'resized_height': MAX_IMAGE_SIZE, 'resized_width': MAX_IMAGE_SIZE}]}]
101
+
102
+ sim1 = torch.cosine_similarity(encode_message(query), encode_message(image_1))
103
+ sim2 = torch.cosine_similarity(encode_message(query), encode_message(image_2))
104
+
105
+ print("Similarities:", sim1.item(), sim2.item())
106
  ```
107
 
108
  ## Training Details