import concurrent.futures
import logging
import math
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
import torch.nn.functional as F
import transformers
from transformers import AutoTokenizer
from transformers_neuronx import GQA, MistralForSampling, NeuronConfig, QuantizationConfig


def padding_ceiling(n: int) -> int:
    """Round n up to the nearest power of two (minimum 1)."""
    if n <= 0:
        return 1
    elif n & (n - 1) == 0:  # n is already a power of two
        return n
    else:
        return 2 ** math.ceil(math.log2(n))


class MyStreamer(transformers.generation.streamers.BaseStreamer):
    """Streamer that records per-token latencies during sampling."""

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        self.token_latencies = []
        self.iter = 0
        self.now = time.time()

    def put(self, tokens):
        # Record the elapsed time since the previous token was emitted.
        now = time.time()
        token_latency = now - self.now
        self.now = now
        self.iter += 1
        self.token_latencies.append(token_latency)

    def end(self):
        print("\n\n")
        print("First 5 token latencies:", self.token_latencies[:5])
        print("Total token latency:", sum(self.token_latencies))


class MistralModel:
    """
    A class for generating text using the Mistral language model on AWS Neuron.
    """

    def __init__(self, model_name: str):
        self.neuron_config = NeuronConfig(
            group_query_attention=GQA.SHARD_OVER_HEADS,
            quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'),
        )
        # self.model_name = 'mistralai/Mistral-7B-Instruct-v0.2'
        self.model_name = model_name
        self.amp: Literal['bf16', 'fp32'] = 'bf16'
        self.batch_size = 1
        self.tp_degree = 2
        self.n_positions = 4096
        self.context_length_estimate = [2289, 4096]
        # self.context_length_estimate = 2289
        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_template = "[INST] {prompt} [/INST]"

    def _load_model(self) -> MistralForSampling:
        """
        Load and compile the Mistral model for Neuron.

        Returns:
            MistralForSampling: The initialized Mistral model.
        """
        model = MistralForSampling.from_pretrained(
            self.model_name,
            amp=self.amp,
            batch_size=self.batch_size,
            tp_degree=self.tp_degree,
            n_positions=self.n_positions,
            neuron_config=self.neuron_config,
            context_length_estimate=self.context_length_estimate,
            # compiler_args=["--model-type=transformer", "--target=inf2", "--auto-cast=all",
            #                "--auto-cast-type=fp8_e4m3", "--optlevel=3", "--enable-saturate-infinity"]
        )
        model.to_neuron()
        return model

    def generate(self, inputs: Union[str, List[int]], parameters: Optional[Dict[str, Any]] = None) -> str:
        """
        Generate text using the Mistral model.

        Args:
            inputs (Union[str, List[int]]): The input prompt or a list of input embeddings.
            parameters (Optional[Dict[str, Any]]): Optional sampling parameters.

        Returns:
            str: The generated text.

        Raises:
            ValueError: If the input type is invalid.
        """
        try:
            parameters = parameters or {}  # guard against parameters=None
            max_new_tokens = parameters.get("max_new_tokens", 256)
            top_k = parameters.get("top_k", 100)
            top_p = parameters.get("top_p", 0.1)
            temperature = parameters.get("temperature", 0.1)
            no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
            print(
                f"parameters max_new_tokens: {max_new_tokens}, top_k: {top_k}, top_p: {top_p}, "
                f"temperature: {temperature}, no_repeat_ngram_size: {no_repeat_ngram_size}")
            if isinstance(inputs, str):
                generated_text = self._generate_from_prompt(
                    inputs, max_new_tokens, top_k, top_p, temperature, no_repeat_ngram_size)
            elif isinstance(inputs, list):
                generated_text = self._generate_from_embeddings(
                    inputs, max_new_tokens, top_k, top_p, temperature, no_repeat_ngram_size)
            else:
                raise ValueError("Invalid input type. Must be str or List[int]")
            return generated_text
        except Exception as e:
            logging.error(f"Error generating text: {e}")
            raise

    def _generate_from_prompt(self, prompt: str, max_new_tokens: int, top_k: int, top_p: float,
                              temperature: float, no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given prompt using the Mistral model.

        Args:
            prompt (str): The input prompt.
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Number of highest-probability tokens kept for sampling.
            top_p (float): Nucleus-sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.

        Returns:
            str: The generated text.
        """
        input_prompt = self.prompt_template.format(prompt=prompt)
        encoded_input = self.tokenizer(input_prompt, return_tensors='pt')
        input_ids = encoded_input.input_ids
        with torch.inference_mode():
            generated_sequence = self.model.sample(
                input_ids,
                sequence_length=min(self.n_positions, input_ids.shape[1] + max_new_tokens),
                start_ids=None,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                no_repeat_ngram_size=no_repeat_ngram_size)
        decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]
        # Drop the echoed prompt: keep only the text after the closing [/INST] tag.
        generated_text = decoded_output[0].split('[/INST]')[1].strip()
        return generated_text

    def _generate_from_embeddings(self, input_embeddings: List[int], max_new_tokens: int, top_k: int,
                                  top_p: float, temperature: float, no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given list of input embeddings using the Mistral model.

        Args:
            input_embeddings (List[int]): A list of input embeddings.
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Number of highest-probability tokens kept for sampling.
            top_p (float): Nucleus-sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.

        Returns:
            str: The generated text.
        """
        s1 = time.time()
        input_embeds_tensor = torch.tensor(input_embeddings)
        input_embeds_length = input_embeds_tensor.shape[1]
        # Left-pad the sequence dimension to the next power of two so the input
        # lines up with a pre-compiled context-length bucket; skip padding if
        # that would reach or exceed n_positions.
        padding_size = padding_ceiling(input_embeds_length)
        if padding_size >= self.n_positions:
            padding_size = input_embeds_length
            padded_input_embeds = input_embeds_tensor
        else:
            padding_gap = padding_size - input_embeds_length
            padded_input_embeds = F.pad(input_embeds_tensor, (0, 0, padding_gap, 0),
                                        value=self.tokenizer.pad_token_id)
        print("ms1 - input_embeds time: ", time.time() - s1)
        s2 = time.time()
        with torch.inference_mode():
            generated_sequence = self.model.sample(
                padded_input_embeds,
                sequence_length=min(self.n_positions, padding_size + max_new_tokens),
                start_ids=None,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                no_repeat_ngram_size=no_repeat_ngram_size,
                streamer=MyStreamer())
        # Decode sequences across worker threads; with batch_size=1 this is
        # equivalent to the commented-out list comprehension below.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            decoded_output = list(executor.map(self.tokenizer.decode, generated_sequence))
        # decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]
        print("ms2 - decoded_output time: ", time.time() - s2)
        generated_text = decoded_output[0].strip()
        return generated_text
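

# Minimal usage sketch. Assumptions not fixed by this module: it runs on an
# Inferentia instance with the Neuron SDK installed, and the model name and
# sampling parameters below are illustrative (the model name is taken from the
# commented-out default above).
if __name__ == "__main__":
    model = MistralModel("mistralai/Mistral-7B-Instruct-v0.2")
    text = model.generate(
        "Explain tensor parallelism in two sentences.",
        parameters={"max_new_tokens": 128, "temperature": 0.2},
    )
    print(text)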