Update README.md
Browse files
README.md
CHANGED
|
@@ -17,6 +17,7 @@ import joblib
|
|
| 17 |
!huggingface-cli login
|
| 18 |
import pandas as pd
|
| 19 |
import torch
|
|
|
|
| 20 |
import torchvision
|
| 21 |
from torchvision import transforms, utils
|
| 22 |
import torch.nn as nn
|
|
@@ -200,46 +201,68 @@ X_test = consistency_checks(X_test, 'title') </pre>
|
|
| 200 |
|
| 201 |
|
| 202 |
<pre>
|
| 203 |
-
def get_embeddings(text_all, tokenizer, model, max_len
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
-
with torch.no_grad():
|
| 223 |
-
model_output = model(**model_input_token)
|
| 224 |
-
cls_embedding = model_output.last_hidden_state[:, 0, :]
|
| 225 |
-
cls_embedding = cls_embedding.squeeze().numpy()
|
| 226 |
-
embeddings.append(cls_embedding)
|
| 227 |
return embeddings </pre>
|
| 228 |
|
| 229 |
|
| 230 |
-
#
|
| 231 |
-
<pre>
|
| 232 |
-
|
| 233 |
-
|
| 234 |
|
| 235 |
-
#
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
-
#
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
-
#this may take awhile to run
|
| 242 |
-
X_test_embeddings_DBERT = get_embeddings(X_test, tokenizer_DBERT, transformer_model_DBERT, max_len = max_len)
|
| 243 |
|
| 244 |
prediction = model.predict(X_test_embeddings_DBERT)
|
| 245 |
</pre>
|
|
|
|
| 17 |
!huggingface-cli login
|
| 18 |
import pandas as pd
|
| 19 |
import torch
|
| 20 |
+
from transformers import AutoTokenizer, AutoModel
|
| 21 |
import torchvision
|
| 22 |
from torchvision import transforms, utils
|
| 23 |
import torch.nn as nn
|
|
|
|
| 201 |
|
| 202 |
|
| 203 |
<pre>
|
| 204 |
+
def get_embeddings(text_all, tokenizer, model, device, max_len=128):
|
| 205 |
+
'''
|
| 206 |
+
Generate embeddings using a transformer model on GPU if available.
|
| 207 |
+
Args:
|
| 208 |
+
- text_all: List of input texts
|
| 209 |
+
- tokenizer: Tokenizer for the model
|
| 210 |
+
- model: Transformer model
|
| 211 |
+
- device: torch.device to run the computations
|
| 212 |
+
- max_len: Maximum token length for the input
|
| 213 |
+
Returns:
|
| 214 |
+
- embeddings: List of embeddings for each input text
|
| 215 |
+
'''
|
| 216 |
+
embeddings = []
|
| 217 |
+
|
| 218 |
+
count = 0
|
| 219 |
+
print('Start embeddings:')
|
| 220 |
+
|
| 221 |
+
for text in text_all:
|
| 222 |
+
count += 1
|
| 223 |
+
if count % (len(text_all) // 10) == 0:
|
| 224 |
+
print(f'{count / len(text_all) * 100:.1f}% done ...')
|
| 225 |
+
|
| 226 |
+
# Tokenize the input text
|
| 227 |
+
model_input_token = tokenizer(
|
| 228 |
+
text,
|
| 229 |
+
add_special_tokens=True,
|
| 230 |
+
max_length=max_len,
|
| 231 |
+
padding='max_length',
|
| 232 |
+
truncation=True,
|
| 233 |
+
return_tensors='pt'
|
| 234 |
+
).to(device) # Move input tensors to GPU
|
| 235 |
+
|
| 236 |
+
# Generate embeddings without gradient computation
|
| 237 |
+
with torch.no_grad():
|
| 238 |
+
model_output = model(**model_input_token)
|
| 239 |
+
cls_embedding = model_output.last_hidden_state[:, 0, :] # Use CLS token embedding
|
| 240 |
+
cls_embedding = cls_embedding.squeeze().cpu().numpy() # Move back to CPU for numpy
|
| 241 |
+
embeddings.append(cls_embedding)
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
return embeddings </pre>
|
| 244 |
|
| 245 |
|
| 246 |
+
# Check for GPU availability
|
| 247 |
+
<pre>
|
| 248 |
+
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load the tokenizer and model for 'distilbert-base-uncased'
print("Loading model and tokenizer...")
tokenizer_news = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model_news = AutoModel.from_pretrained('distilbert-base-uncased').to(device)

# Set the model to evaluation mode
model_news.eval()

############################################# DBERT UNCASED Embedding #############################################
# The embeddings are computed for X_test, so the progress messages say "test data"
print("Computing DBERT embeddings for test data...")
X_test_embeddings_DBERT = get_embeddings(X_test, tokenizer_news, model_news, device, max_len=128)
print("DBERT embeddings for test data computed!")

# NOTE(review): `model` is not defined in this excerpt -- presumably the trained
# classifier loaded earlier in the README; confirm before running.
prediction = model.predict(X_test_embeddings_DBERT)
|
| 268 |
</pre>
|