Update kraken_model/modeling_kraken.py
Browse files
kraken_model/modeling_kraken.py
CHANGED
|
@@ -41,10 +41,6 @@ class KrakenForCausalLM(PreTrainedModel):
|
|
| 41 |
model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
|
| 42 |
return model_keys[model_decision_index]
|
| 43 |
|
| 44 |
-
def expert_tokenizer(self, text):
    """Return the tokenizer of the expert model selected for *text*.

    Delegates routing to ``self.determine_model`` and looks the resulting
    expert key up in the ``self.tokenizers`` mapping.
    """
    return self.tokenizers[self.determine_model(text)]
|
| 47 |
-
|
| 48 |
|
| 49 |
def generate(self, input_ids, **generate_kwargs):
|
| 50 |
# Tokenize the input_ids
|
|
@@ -75,8 +71,17 @@ class KrakenForCausalLM(PreTrainedModel):
|
|
| 75 |
tok_input_ids = tok.input_ids.to(current_device)
|
| 76 |
tok_attention_mask = tok.attention_mask.to(current_device)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
|
|
|
|
| 41 |
model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
|
| 42 |
return model_keys[model_decision_index]
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def generate(self, input_ids, **generate_kwargs):
|
| 46 |
# Tokenize the input_ids
|
|
|
|
| 71 |
tok_input_ids = tok.input_ids.to(current_device)
|
| 72 |
tok_attention_mask = tok.attention_mask.to(current_device)
|
| 73 |
|
| 74 |
+
|
| 75 |
+
# Generate text using the modified model
|
| 76 |
+
output_ids = model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
|
| 77 |
+
|
| 78 |
+
# Decode the output using the expert tokenizer
|
| 79 |
+
decoded_text = self.tokenizers[model_key].decode(output_ids[0], skip_special_tokens=True)
|
| 80 |
+
|
| 81 |
+
# Retokenize the decoded text using the base tokenizer for external compatibility
|
| 82 |
+
retokenized_ids = self.tokenizer(decoded_text, return_tensors="pt").input_ids.to(current_device)
|
| 83 |
+
|
| 84 |
+
return retokenized_ids
|
| 85 |
|
| 86 |
|
| 87 |
|