import logging
from typing import Union, List, Optional, Dict, Any, Literal
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
import transformers
from transformers_neuronx import MistralForSampling, GQA, NeuronConfig, QuantizationConfig
import time
import math
import concurrent.futures


def padding_ceiling(n):
    """Round n up to the nearest power of two (minimum 1)."""
    if n <= 0:
        return 1
    elif n & (n - 1) == 0:
        return n
    else:
        return 2 ** math.ceil(math.log2(n))


class MyStreamer(transformers.generation.streamers.BaseStreamer):
    """Streamer that records the latency of each generated token."""

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        self.token_latencies = []
        self.iter = 0
        self.now = time.time()

    def put(self, tokens):
        # Record the time elapsed since the previous token was emitted.
        now = time.time()
        token_latency = now - self.now
        self.now = now
        self.iter += 1
        self.token_latencies.append(token_latency)

    def end(self):
        print("\n\n")
        print("First 5 token latencies:", self.token_latencies[:5])
        print("Sum of all token latencies:", sum(self.token_latencies))


class MistralModel:
    """
    A class for generating text using the Mistral language model.
    """

    def __init__(self, model_name):
        # Shard grouped-query-attention heads across NeuronCores and store weights
        # as int8, dequantizing to bfloat16 at runtime.
        self.neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS,
                                          quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'))

        self.model_name = model_name
        self.amp: Literal['bf16', 'fp32'] = 'bf16'
        self.batch_size = 1
        self.tp_degree = 2
        self.n_positions = 4096
        self.context_length_estimate = [2289, 4096]

        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_template = "<s>[INST] {prompt} [/INST]"

    def _load_model(self) -> MistralForSampling:
        """
        Load and initialize the Mistral model.

        Returns:
            MistralForSampling: The initialized Mistral model.
        """
        model = MistralForSampling.from_pretrained(
            self.model_name,
            amp=self.amp,
            batch_size=self.batch_size,
            tp_degree=self.tp_degree,
            n_positions=self.n_positions,
            neuron_config=self.neuron_config,
            context_length_estimate=self.context_length_estimate,
        )
        # Compile the model and load the weights onto the NeuronCores.
        model.to_neuron()
        return model

    def generate(self, inputs: Union[str, List[int]], parameters: Optional[Dict[str, Any]] = None) -> str:
        """
        Generate text using the Mistral model.

        Args:
            inputs (Union[str, List[int]]): The input prompt or a list of input embeddings.
            parameters (Optional[Dict[str, Any]]): Optional parameters for text generation.

        Returns:
            str: The generated text.

        Raises:
            ValueError: If the input type is invalid.
        """
        try:
            # Fall back to an empty dict so defaults apply when no parameters are given.
            parameters = parameters or {}
            max_new_tokens = parameters.get("max_new_tokens", 256)
            top_k = parameters.get("top_k", 100)
            top_p = parameters.get("top_p", 0.1)
            temperature = parameters.get("temperature", 0.1)
            no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
            print(f"parameters max_new_tokens: {max_new_tokens}, top_k: {top_k}, top_p: {top_p}, "
                  f"temperature: {temperature}, no_repeat_ngram_size: {no_repeat_ngram_size}")

            if isinstance(inputs, str):
                generated_text = self._generate_from_prompt(inputs, max_new_tokens, top_k, top_p, temperature,
                                                            no_repeat_ngram_size)
            elif isinstance(inputs, list):
                generated_text = self._generate_from_embeddings(inputs, max_new_tokens, top_k, top_p, temperature,
                                                                no_repeat_ngram_size)
            else:
                raise ValueError("Invalid input type. Must be str or List[int]")

            return generated_text
        except Exception as e:
            logging.error(f"Error generating text: {e}")
            raise

    def _generate_from_prompt(self, prompt: str, max_new_tokens: int, top_k: int, top_p: float, temperature: float,
                              no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given prompt using the Mistral model.

        Args:
            prompt (str): The input prompt.
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Number of highest-probability tokens kept for sampling.
            top_p (float): Nucleus sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.

        Returns:
            str: The generated text.
        """
        input_prompt = self.prompt_template.format(prompt=prompt)
        encoded_input = self.tokenizer(input_prompt, return_tensors='pt')
        input_ids = encoded_input.input_ids

        with torch.inference_mode():
            generated_sequence = self.model.sample(input_ids,
                                                   sequence_length=min(self.n_positions,
                                                                       input_ids.shape[1] + max_new_tokens),
                                                   start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
                                                   no_repeat_ngram_size=no_repeat_ngram_size)
        decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]

        # str.strip("</s>") strips individual characters, not the "</s>" token,
        # so remove the end-of-sequence marker explicitly.
        generated_text = decoded_output[0].split('[/INST]')[1].replace("</s>", "").strip()
        return generated_text

    def _generate_from_embeddings(self, input_embeddings: List[int], max_new_tokens: int, top_k: int, top_p: float,
                                  temperature: float, no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given list of input embeddings using the Mistral model.

        Args:
            input_embeddings (List[int]): A list of input embeddings.
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Number of highest-probability tokens kept for sampling.
            top_p (float): Nucleus sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.

        Returns:
            str: The generated text.
        """
        s1 = time.time()
        input_embeds_tensor = torch.tensor(input_embeddings)
        input_embeds_length = input_embeds_tensor.shape[1]
        # Left-pad the sequence dimension up to the next power of two, unless that
        # would exceed the model's maximum number of positions.
        padding_size = padding_ceiling(input_embeds_length)
        if padding_size >= self.n_positions:
            padding_size = input_embeds_length
            padded_input_embeds = input_embeds_tensor
        else:
            padding_gap = padding_size - input_embeds_length
            padded_input_embeds = F.pad(input_embeds_tensor, (0, 0, padding_gap, 0),
                                        value=self.tokenizer.pad_token_id)
        print("ms1 - input_embeds time: ", time.time() - s1)

        s2 = time.time()
        with torch.inference_mode():
            generated_sequence = self.model.sample(padded_input_embeds,
                                                   sequence_length=min(self.n_positions,
                                                                       padding_size + max_new_tokens),
                                                   start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
                                                   no_repeat_ngram_size=no_repeat_ngram_size, streamer=MyStreamer())
        # Decode the generated sequences in parallel threads.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            decoded_output = list(executor.map(self.tokenizer.decode, generated_sequence))
        print("ms2 - decoded_output time: ", time.time() - s2)

        generated_text = decoded_output[0].replace("</s>", "").strip()
        return generated_text
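

if __name__ == "__main__":
    # Minimal usage sketch. The model identifier and sampling parameters below are
    # assumptions for illustration; point MistralModel at whichever Mistral instruct
    # checkpoint you actually compile for Neuron.
    mistral = MistralModel("mistralai/Mistral-7B-Instruct-v0.1")
    text = mistral.generate("Summarize what AWS Inferentia is in one sentence.",
                            parameters={"max_new_tokens": 128, "temperature": 0.7})
    print(text)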