GenerTeam committed
Commit 29fed02 · verified · 1 Parent(s): 078c762

Update README.md

Files changed (1)
  1. README.md +31 -23
README.md CHANGED
@@ -101,24 +101,27 @@ print(decoded_sequences)
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Load the tokenizer and model.
+# Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("GENERator-v2-prokaryote-1.2b-base", trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained("GENERator-v2-prokaryote-1.2b-base")
 
+# Get model configuration
 config = model.config
 max_length = config.max_position_embeddings
 
-# Define input sequences.
+# Define input sequences
 sequences = [
     "ATGAGGTGGCAAGAAATGGGCTAC",
     "GAATTCCATGAGGCTATAGAATAATCTAAGAGAAAT"
 ]
 
-# Tokenize the sequences with add_special_tokens=True to automatically add special tokens,
-# such as the BOS and EOS tokens, at the appropriate positions.
+# Truncate each sequence to the nearest multiple of 6 and prepend the BOS token
+processed_sequences = [tokenizer.bos_token + seq[:len(seq)//6*6] for seq in sequences]
+
+# Tokenization
 tokenizer.padding_side = "right"
 inputs = tokenizer(
-    sequences,
+    processed_sequences,
     add_special_tokens=True,
     return_tensors="pt",
     padding=True,
@@ -126,29 +129,34 @@ inputs = tokenizer(
     max_length=max_length
 )
 
-# Perform a forward pass through the model to obtain the outputs, including hidden states.
+# Model Inference
 with torch.inference_mode():
     outputs = model(**inputs, output_hidden_states=True)
 
-# Retrieve the hidden states from the last layer.
-hidden_states = outputs.hidden_states[-1]  # Shape: (batch_size, sequence_length, hidden_size)
-
-# Use the attention_mask to determine the index of the last token in each sequence.
-# Since add_special_tokens=True is used, the last token is typically the EOS token.
+hidden_states = outputs.hidden_states[-1]
 attention_mask = inputs["attention_mask"]
-last_token_indices = attention_mask.sum(dim=1) - 1  # Index of the last token for each sequence
-
-# Extract the embedding corresponding to the EOS token for each sequence.
-seq_embeddings = []
-for i, token_index in enumerate(last_token_indices):
-    # Fetch the embedding for the last token (EOS token).
-    seq_embedding = hidden_states[i, token_index, :]
-    seq_embeddings.append(seq_embedding)
-
-# Stack the embeddings into a tensor with shape (batch_size, hidden_size)
-seq_embeddings = torch.stack(seq_embeddings)
 
-print("Sequence Embeddings:", seq_embeddings)
+# Option 1: Last token (EOS) embedding
+last_token_indices = attention_mask.sum(dim=1) - 1
+eos_embeddings = hidden_states[torch.arange(hidden_states.size(0)), last_token_indices, :]
+
+# Option 2: Mean pooling over all tokens
+expanded_mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(torch.float32)
+sum_embeddings = torch.sum(hidden_states * expanded_mask, dim=1)
+mean_embeddings = sum_embeddings / expanded_mask.sum(dim=1)
+
+# Output
+print("EOS (Last Token) Embeddings:", eos_embeddings)
+print("Mean Pooling Embeddings:", mean_embeddings)
+
+# ============================================================================
+# Additional notes:
+# - The preprocessing step keeps sequence lengths at multiples of 6 for the 6-mer tokenizer
+# - For a causal LM, the last-token (EOS) embedding is commonly used
+# - Mean pooling considers all tokens, including BOS and content tokens
+# - The choice depends on your downstream task requirements
+# - Both methods handle variable sequence lengths via the attention mask
+# ============================================================================
 
 ```
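The functional change in the first hunk is the new preprocessing line. Below is a minimal sanity check of what it does, with a placeholder "<BOS>" string standing in for the actual tokenizer.bos_token (whose value is not shown in this commit); the helper name preprocess is hypothetical.

```python
# Placeholder for tokenizer.bos_token; the real value is model-specific and
# is not reproduced in this commit.
BOS = "<BOS>"

def preprocess(seq: str) -> str:
    # Drop trailing bases so the length is a multiple of 6 (whole 6-mers),
    # then prepend the BOS marker, mirroring the list comprehension in the diff.
    return BOS + seq[: len(seq) // 6 * 6]

# Both README sequences are already multiples of 6 (24 and 36 bases), so only
# BOS is prepended; a 25-base sequence loses its trailing base.
assert preprocess("ATGAGGTGGCAAGAAATGGGCTAC") == BOS + "ATGAGGTGGCAAGAAATGGGCTAC"
assert preprocess("ATGAGGTGGCAAGAAATGGGCTACG") == BOS + "ATGAGGTGGCAAGAAATGGGCTAC"
```

Truncating the tail (rather than padding it out) keeps the input divisible into whole 6-mer tokens, which is what the commit's added notes say the tokenizer expects.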
 
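The second hunk replaces the per-sequence Python loop with vectorized indexing and adds mean pooling as a second option. The sketch below exercises both strategies on dummy tensors; the random values merely stand in for GENERator hidden states, and the shapes are illustrative.

```python
import torch

torch.manual_seed(0)
# Batch of 2 sequences, padded length 5, hidden size 4; the second sequence
# has two right-padding positions, matching padding_side="right" above.
hidden_states = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Option 1: last non-padding token, i.e. the EOS position under right padding.
last_token_indices = attention_mask.sum(dim=1) - 1  # tensor([4, 2])
eos_embeddings = hidden_states[torch.arange(hidden_states.size(0)), last_token_indices, :]
assert torch.equal(eos_embeddings[1], hidden_states[1, 2, :])

# Option 2: mean over non-padding tokens only; padded positions contribute
# zero to the numerator and are excluded from the denominator.
expanded_mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(torch.float32)
mean_embeddings = torch.sum(hidden_states * expanded_mask, dim=1) / expanded_mask.sum(dim=1)
assert torch.allclose(mean_embeddings[1], hidden_states[1, :3].mean(dim=0))

print(eos_embeddings.shape, mean_embeddings.shape)  # both torch.Size([2, 4])
```

In both cases the attention mask, not the padded length, drives the computation, which is the "variable sequence lengths" property the commit's added notes mention.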