Text Ranking
Elias Wendt committed on
Commit
a4dc17a
·
1 Parent(s): d81a515

Align sample code with github readme

Browse files
Files changed (1) hide show
  1. README.md +24 -39
README.md CHANGED
@@ -121,45 +121,30 @@ Use the code below to get started with the model.
121
 
122
  ```python
123
  from utils.regression_head import RegressionHead
124
- from transformers import AutoModel, AutoTokenizer
125
- from torch import bfloat16, no_grad, nn
126
- from huggingface_hub import hf_hub_download
127
-
128
-
129
- embedding_model = AutoModel.from_pretrained(
130
- 'Snowflake/snowflake-arctic-embed-m-v2.0',
131
- trust_remote_code=True,
132
- torch_dtype=bfloat16,
133
- unpad_inputs=True,
134
- add_pooling_layer=False,
135
- ).cuda()
136
-
137
- file_path = hf_hub_download(
138
- repo_id="Jackal-AI/JQL-Edu-Heads",
139
- filename="checkpoints/edu-mistral-snowflake-balanced.ckpt",
140
- repo_type="model"
141
- )
142
-
143
- regression_head = RegressionHead.load_from_checkpoint(
144
- file_path,
145
- map_location='cuda'
146
- ).to(bfloat16)
147
-
148
-
149
- tokenizer = AutoTokenizer.from_pretrained('Snowflake/snowflake-arctic-embed-m-v2.0')
150
- tokens = tokenizer(
151
- text=["Rome was founded by Romulus and Remus in 753 BC.", "Huggingface is cool!"],
152
- padding='longest',
153
- truncation=True,
154
- return_tensors='pt'
155
- ).to('cuda')
156
-
157
- with no_grad():
158
- cls_token = embedding_model(**tokens).last_hidden_state[:, 0]
159
- normalized_cls_token = nn.functional.normalize(cls_token, p=2, dim=1)
160
- predicted_edu_score = regression_head(normalized_cls_token).squeeze(-1)
161
-
162
- print('predicted edu scores:', predicted_edu_score)
163
  ```
164
  <!--
165
  ## Training Details
 
121
 
122
  ```python
123
  from utils.regression_head import RegressionHead
124
+ from transformers.utils.hub import cached_file
125
+ from utils.embedder import get_embedder_instance
126
+ import torch
127
+
128
+ # load embedder
129
+ device = 'cuda'
130
+ embedder = get_embedder_instance('Snowflake/snowflake-arctic-embed-m-v2.0', device, torch.bfloat16)
131
+ # load JQL Edu annotation heads
132
+ regression_head_checkpoints = {
133
+ 'Edu-JQL-Gemma-SF': cached_file('Jackal-AI/JQL-Edu-Heads', 'checkpoints/edu-gemma-snowflake-balanced.ckpt'),
134
+ 'Edu-JQL-Mistral-SF': cached_file('Jackal-AI/JQL-Edu-Heads', 'checkpoints/edu-mistral-snowflake-balanced.ckpt'),
135
+ 'Edu-JQL-Llama-SF': cached_file('Jackal-AI/JQL-Edu-Heads', 'checkpoints/edu-llama-snowflake-balanced.ckpt'),
136
+ }
137
+ regression_heads = {}
138
+ for name, path in regression_head_checkpoints.items():
139
+ regression_heads[name] = RegressionHead.load_from_checkpoint(path, map_location=device).to(torch.bfloat16)
140
+
141
+ # Given a single document
142
+ doc = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua'
143
+ embeddings = embedder.embed([doc])
144
+ scores = {}
145
+ with torch.no_grad():
146
+ for name, regression_head in regression_heads.items():
147
+ scores[f'score_{name}'] = regression_head(embeddings).cpu().squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  ```
149
  <!--
150
  ## Training Details