Elias Wendt committed on
Commit ·
a4dc17a
1
Parent(s): d81a515
Align sample code with github readme
Browse files
README.md
CHANGED
|
@@ -121,45 +121,30 @@ Use the code below to get started with the model.
|
|
| 121 |
|
| 122 |
```python
|
| 123 |
from utils.regression_head import RegressionHead
|
| 124 |
-
from transformers import
|
| 125 |
-
from
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
tokenizer = AutoTokenizer.from_pretrained('Snowflake/snowflake-arctic-embed-m-v2.0')
|
| 150 |
-
tokens = tokenizer(
|
| 151 |
-
text=["Rome was founded by Romulus and Remus in 753 BC.", "Huggingface is cool!"],
|
| 152 |
-
padding='longest',
|
| 153 |
-
truncation=True,
|
| 154 |
-
return_tensors='pt'
|
| 155 |
-
).to('cuda')
|
| 156 |
-
|
| 157 |
-
with no_grad():
|
| 158 |
-
cls_token = embedding_model(**tokens).last_hidden_state[:, 0]
|
| 159 |
-
normalized_cls_token = nn.functional.normalize(cls_token, p=2, dim=1)
|
| 160 |
-
predicted_edu_score = regression_head(normalized_cls_token).squeeze(-1)
|
| 161 |
-
|
| 162 |
-
print('predicted edu scores:', predicted_edu_score)
|
| 163 |
```
|
| 164 |
<!--
|
| 165 |
## Training Details
|
|
|
|
| 121 |
|
| 122 |
```python
|
| 123 |
from utils.regression_head import RegressionHead
|
| 124 |
+
from transformers.utils.hub import cached_file
|
| 125 |
+
from utils.embedder import get_embedder_instance
|
| 126 |
+
import torch
|
| 127 |
+
|
| 128 |
+
# load embedder
|
| 129 |
+
device = 'cuda'
|
| 130 |
+
embedder = get_embedder_instance('Snowflake/snowflake-arctic-embed-m-v2.0', device, torch.bfloat16)
|
| 131 |
+
# load JQL Edu annotation heads
|
| 132 |
+
regression_head_checkpoints = {
|
| 133 |
+
'Edu-JQL-Gemma-SF': cached_file('Jackal-AI/JQL-Edu-Heads', 'checkpoints/edu-gemma-snowflake-balanced.ckpt'),
|
| 134 |
+
'Edu-JQL-Mistral-SF': cached_file('Jackal-AI/JQL-Edu-Heads', 'checkpoints/edu-mistral-snowflake-balanced.ckpt'),
|
| 135 |
+
'Edu-JQL-Llama-SF': cached_file('Jackal-AI/JQL-Edu-Heads', 'checkpoints/edu-llama-snowflake-balanced.ckpt'),
|
| 136 |
+
}
|
| 137 |
+
regression_heads = {}
|
| 138 |
+
for name, path in regression_head_checkpoints.items():
|
| 139 |
+
regression_heads[name] = RegressionHead.load_from_checkpoint(path, map_location=device).to(torch.bfloat16)
|
| 140 |
+
|
| 141 |
+
# Given a single document
|
| 142 |
+
doc = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua'
|
| 143 |
+
embeddings = embedder.embed([doc])
|
| 144 |
+
scores = {}
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
for name, regression_head in regression_heads.items():
|
| 147 |
+
scores[f'score_{name}'] = regression_head(embeddings).cpu().squeeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
```
|
| 149 |
<!--
|
| 150 |
## Training Details
|