import concurrent.futures
import logging
import math
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
import torch.nn.functional as F
import transformers
from transformers import AutoTokenizer
from transformers_neuronx import GQA, MistralForSampling, NeuronConfig, QuantizationConfig


def padding_ceiling(n):
    """Round n up to the nearest power of two, returning 1 for non-positive n."""
    if n <= 0:
        return 1
    elif n & (n - 1) == 0:
        return n
    else:
        return 2 ** math.ceil(math.log2(n))

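# Illustration (not part of the original source): padding_ceiling buckets sequence
# lengths into powers of two, e.g. padding_ceiling(0) == 1, padding_ceiling(1024) == 1024,
# and padding_ceiling(2289) == 4096. _generate_from_embeddings below uses it to choose a
# padded sequence length before sampling.

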
class MyStreamer(transformers.generation.streamers.BaseStreamer):
    """Streamer that records the latency of each generated token."""

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        self.token_latencies = []
        self.iter = 0
        self.now = time.time()

    def put(self, tokens):
        # Record the time elapsed since the previous token (or since reset() for the first one).
        now = time.time()
        token_latency = now - self.now
        self.now = now
        self.iter += 1
        self.token_latencies.append(token_latency)

    def end(self):
        print("\n\n")
        print("First 5 token latencies:", self.token_latencies[:5])
        print("Sum of token latencies:", sum(self.token_latencies))

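# A hypothetical way to consume the recorded latencies after generation (not in the
# original source); _generate_from_embeddings below passes a streamer to model.sample():
#   streamer = MyStreamer()
#   # ... run model.sample(..., streamer=streamer) ...
#   tokens_per_second = len(streamer.token_latencies) / sum(streamer.token_latencies)

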
class MistralModel:
    """
    A class for generating text with a Mistral model compiled for AWS Neuron
    via transformers_neuronx.
    """

    def __init__(self, model_name: str):
        # Shard the grouped-query-attention KV heads across Neuron cores and quantize
        # weights to int8, dequantizing to bfloat16 during compute.
        self.neuron_config = NeuronConfig(
            group_query_attention=GQA.SHARD_OVER_HEADS,
            quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'),
        )

        self.model_name = model_name
        self.amp: Literal['bf16', 'fp32'] = 'bf16'
        self.batch_size = 1
        self.tp_degree = 2
        self.n_positions = 4096
        # Context-length buckets the Neuron compiler optimizes context encoding for.
        self.context_length_estimate = [2289, 4096]

        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_template = "<s>[INST] {prompt} [/INST]"

    def _load_model(self) -> MistralForSampling:
        """
        Load and initialize the Mistral model.

        Returns:
            MistralForSampling: The initialized Mistral model.
        """
        model = MistralForSampling.from_pretrained(
            self.model_name,
            amp=self.amp,
            batch_size=self.batch_size,
            tp_degree=self.tp_degree,
            n_positions=self.n_positions,
            neuron_config=self.neuron_config,
            context_length_estimate=self.context_length_estimate,
        )
        # Compile the model and load it onto the Neuron cores.
        model.to_neuron()
        return model

    def generate(self, inputs: Union[str, List[int]], parameters: Optional[Dict[str, Any]] = None) -> str:
        """
        Generate text using the Mistral model.

        Args:
            inputs (Union[str, List[int]]): The input prompt or a list of input embeddings.
            parameters (Optional[Dict[str, Any]]): Optional parameters for text generation.

        Returns:
            str: The generated text.

        Raises:
            ValueError: If the input type is invalid.
        """
        try:
            # Fall back to an empty dict so the defaults below apply when parameters is None.
            parameters = parameters or {}
            max_new_tokens = parameters.get("max_new_tokens", 256)
            top_k = parameters.get("top_k", 100)
            top_p = parameters.get("top_p", 0.1)
            temperature = parameters.get("temperature", 0.1)
            no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
            print(f"parameters max_new_tokens: {max_new_tokens}, top_k: {top_k}, top_p: {top_p}, "
                  f"temperature: {temperature}, no_repeat_ngram_size: {no_repeat_ngram_size}")

            if isinstance(inputs, str):
                generated_text = self._generate_from_prompt(inputs, max_new_tokens, top_k, top_p, temperature,
                                                            no_repeat_ngram_size)
            elif isinstance(inputs, list):
                generated_text = self._generate_from_embeddings(inputs, max_new_tokens, top_k, top_p, temperature,
                                                                no_repeat_ngram_size)
            else:
                raise ValueError("Invalid input type. Must be str or List[int]")

            return generated_text
        except Exception as e:
            logging.error(f"Error generating text: {e}")
            raise

    def _generate_from_prompt(self, prompt: str, max_new_tokens: int, top_k: int, top_p: float, temperature: float,
                              no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given prompt using the Mistral model.

        Args:
            prompt (str): The input prompt.
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Number of highest-probability tokens kept for sampling.
            top_p (float): Nucleus-sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.

        Returns:
            str: The generated text.
        """
        input_prompt = self.prompt_template.format(prompt=prompt)
        encoded_input = self.tokenizer(input_prompt, return_tensors='pt')
        input_ids = encoded_input.input_ids

        with torch.inference_mode():
            generated_sequence = self.model.sample(input_ids,
                                                   sequence_length=min(self.n_positions,
                                                                       input_ids.shape[1] + max_new_tokens),
                                                   start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
                                                   no_repeat_ngram_size=no_repeat_ngram_size)
        decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]

        # Keep only the completion after the instruction marker and drop the EOS token.
        generated_text = decoded_output[0].split('[/INST]', 1)[1].replace('</s>', '').strip()
        return generated_text

    def _generate_from_embeddings(self, input_embeddings: List[int], max_new_tokens: int, top_k: int, top_p: float,
                                  temperature: float, no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given list of input embeddings using the Mistral model.

        Args:
            input_embeddings (List[int]): A list of input embeddings; the resulting tensor's
                second dimension is treated as the sequence length.
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Number of highest-probability tokens kept for sampling.
            top_p (float): Nucleus-sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.

        Returns:
            str: The generated text.
        """
        s1 = time.time()
        input_embeds_tensor = torch.tensor(input_embeddings)
        input_embeds_length = input_embeds_tensor.shape[1]
        padding_size = padding_ceiling(input_embeds_length)
        if padding_size >= self.n_positions:
            # Already at or beyond the largest bucket; use the input as-is.
            padding_size = input_embeds_length
            padded_input_embeds = input_embeds_tensor
        else:
            # Left-pad the sequence dimension up to the power-of-two bucket,
            # filling with the tokenizer's pad token id.
            padding_gap = padding_size - input_embeds_length
            padded_input_embeds = F.pad(input_embeds_tensor, (0, 0, padding_gap, 0),
                                        value=self.tokenizer.pad_token_id)
        print("ms1 - input_embeds time: ", time.time() - s1)

        s2 = time.time()
        with torch.inference_mode():
            generated_sequence = self.model.sample(padded_input_embeds,
                                                   sequence_length=min(self.n_positions,
                                                                       padding_size + max_new_tokens),
                                                   start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
                                                   no_repeat_ngram_size=no_repeat_ngram_size, streamer=MyStreamer())
        # Decode the generated sequences in parallel threads.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            decoded_output = list(executor.map(self.tokenizer.decode, generated_sequence))
        print("ms2 - decoded_output time: ", time.time() - s2)

        # Drop the EOS token and surrounding whitespace.
        generated_text = decoded_output[0].replace('</s>', '').strip()
        return generated_text


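# Minimal usage sketch (not part of the original source). The model identifier and
# generation parameters below are assumptions for illustration only; point model_name
# at whatever Mistral checkpoint or local path your transformers_neuronx setup expects.
if __name__ == "__main__":
    model = MistralModel("mistralai/Mistral-7B-Instruct-v0.1")  # hypothetical model id
    output = model.generate(
        "Explain tensor parallelism in one paragraph.",
        parameters={"max_new_tokens": 128, "temperature": 0.7, "top_p": 0.9},
    )
    print(output)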