# inf2_dir/app/backend_model.py
import concurrent.futures
import logging
import math
import time
from typing import Union, List, Optional, Dict, Any, Literal

import torch
import torch.nn.functional as F
import transformers
from transformers import AutoTokenizer
from transformers_neuronx import MistralForSampling, GQA, NeuronConfig, QuantizationConfig

def padding_ceiling(n: int) -> int:
    """Round n up to the next power of two (minimum 1)."""
    if n <= 0:
        return 1
    elif n & (n - 1) == 0:  # n is already a power of 2
        return n
    else:
        return 2 ** math.ceil(math.log2(n))
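
# Quick illustration of padding_ceiling with the context lengths used below
# (purely illustrative; 2289 and 4096 are the context_length_estimate values):
#   padding_ceiling(2289) -> 4096
#   padding_ceiling(4096) -> 4096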


class MyStreamer(transformers.generation.streamers.BaseStreamer):
    """Streamer that records the latency of each generated token."""

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        self.token_latencies = []
        self.iter = 0
        self.now = time.time()

    def put(self, tokens):
        # Record the time elapsed since the previous token (or since reset()
        # for the first token, i.e. the time to first token).
        now = time.time()
        token_latency = now - self.now
        self.now = now
        self.iter += 1
        self.token_latencies.append(token_latency)

    def end(self):
        print("\n\n")
        print("First 5 token latencies:", self.token_latencies[:5])
        print("Sum of token latencies:", sum(self.token_latencies))


class MistralModel:
    """
    A class for generating text with a Mistral model compiled for AWS Inferentia2
    via transformers-neuronx.
    """

    def __init__(self, model_name):
        # Shard the grouped-query-attention KV cache over heads and quantize
        # weights to int8 ('s8'), dequantizing to bf16 at runtime.
        self.neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS,
                                          quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'))
# self.model_name = 'mistralai/Mistral-7B-Instruct-v0.2'
self.model_name = model_name
self.amp: Literal['bf16', 'fp32'] = 'bf16'
self.batch_size = 1
self.tp_degree = 2
self.n_positions = 4096
self.context_length_estimate = [2289, 4096]
# self.context_length_estimate = 2289
self.model = self._load_model()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.prompt_template = "<s>[INST] {prompt} [/INST]"
def _load_model(self) -> MistralForSampling:
"""
Load and initialize the Mistral model.
Returns:
MistralForSampling: The initialized Mistral model.
"""
model = MistralForSampling.from_pretrained(
self.model_name,
amp=self.amp,
batch_size=self.batch_size,
tp_degree=self.tp_degree,
n_positions=self.n_positions,
neuron_config=self.neuron_config,
context_length_estimate=self.context_length_estimate,
# compiler_args=["--model-type=transformer", "--target=inf2", "--auto-cast=all", "--auto-cast-type=fp8_e4m3", "--optlevel=3", "--enable-saturate-infinity"]
)
        # Compile the model and load it onto the NeuronCores.
        model.to_neuron()
return model
def generate(self, inputs: Union[str, List[int]], parameters: Optional[Dict[str, Any]] = None) -> str:
"""
Generate text using the Mistral model.
Args:
inputs (Union[str, List[int]]): The input prompt or a list of input embeddings.
parameters (Optional[Dict[str, Any]]): Optional parameters for text generation.
Returns:
str: The generated text.
Raises:
ValueError: If the input type is invalid.
"""
        try:
            parameters = parameters or {}
            max_new_tokens = parameters.get("max_new_tokens", 256)
            top_k = parameters.get("top_k", 100)
            top_p = parameters.get("top_p", 0.1)
            temperature = parameters.get("temperature", 0.1)
            no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
            print(
                f"parameters max_new_tokens: {max_new_tokens}, top_k: {top_k}, top_p: {top_p}, "
                f"temperature: {temperature}, no_repeat_ngram_size: {no_repeat_ngram_size}")
if isinstance(inputs, str):
generated_text = self._generate_from_prompt(inputs, max_new_tokens, top_k, top_p, temperature,
no_repeat_ngram_size)
elif isinstance(inputs, list):
generated_text = self._generate_from_embeddings(inputs, max_new_tokens, top_k, top_p, temperature,
no_repeat_ngram_size)
else:
raise ValueError("Invalid input type. Must be str or List[int]")
return generated_text
except Exception as e:
logging.error(f"Error generating text: {e}")
raise
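    # Example `parameters` payload (illustrative values only; the keys mirror the
    # .get() defaults used in generate() above):
    #   {"max_new_tokens": 256, "top_k": 100, "top_p": 0.1,
    #    "temperature": 0.1, "no_repeat_ngram_size": 3}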
    def _generate_from_prompt(self, prompt: str, max_new_tokens: int, top_k: int, top_p: float, temperature: float,
                              no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given prompt using the Mistral model.
        Args:
            prompt (str): The input prompt.
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Keep only the top-k most likely tokens when sampling.
            top_p (float): Nucleus-sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.
        Returns:
            str: The generated text.
        """
input_prompt = self.prompt_template.format(prompt=prompt)
encoded_input = self.tokenizer(input_prompt, return_tensors='pt')
input_ids = encoded_input.input_ids
with torch.inference_mode():
generated_sequence = self.model.sample(input_ids, sequence_length=min(self.n_positions,
input_ids.shape[1] + max_new_tokens),
start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
no_repeat_ngram_size=no_repeat_ngram_size)
decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]
        # Keep only the completion (after [/INST]) and drop the trailing </s> marker.
        generated_text = decoded_output[0].split('[/INST]')[1].replace('</s>', '').strip()
return generated_text
    def _generate_from_embeddings(self, input_embeddings: List[int], max_new_tokens: int, top_k: int, top_p: float,
                                  temperature: float, no_repeat_ngram_size: int) -> str:
        """
        Generate text from a given list of input embeddings using the Mistral model.
        Args:
            input_embeddings (List[int]): A nested list of input embeddings (converted to a tensor
                whose second dimension is treated as the sequence length).
            max_new_tokens (int): The maximum number of new tokens to generate.
            top_k (int): Keep only the top-k most likely tokens when sampling.
            top_p (float): Nucleus-sampling probability mass.
            temperature (float): Sampling temperature.
            no_repeat_ngram_size (int): Size of n-grams that may not repeat.
        Returns:
            str: The generated text.
        """
        s1 = time.time()
        input_embeds_tensor = torch.tensor(input_embeddings)
        input_embeds_length = input_embeds_tensor.shape[1]
        # Left-pad the sequence dimension up to the next power of two (capped at
        # n_positions) so the request lands on a compiled context-length bucket.
        padding_size = padding_ceiling(input_embeds_length)
        if padding_size >= self.n_positions:
            padding_size = input_embeds_length
            padded_input_embeds = input_embeds_tensor
        else:
            padding_gap = padding_size - input_embeds_length
            padded_input_embeds = F.pad(input_embeds_tensor, (0, 0, padding_gap, 0), value=self.tokenizer.pad_token_id)
print("ms1 - input_embeds time: ", time.time() - s1)
s2 = time.time()
with torch.inference_mode():
generated_sequence = self.model.sample(padded_input_embeds,
sequence_length=min(self.n_positions, padding_size + max_new_tokens),
start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
no_repeat_ngram_size=no_repeat_ngram_size, streamer=MyStreamer())
        # Decode all sequences in the batch concurrently.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            decoded_output = list(executor.map(self.tokenizer.decode, generated_sequence))
        # decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]
        print("ms2 - decoded_output time (s): ", time.time() - s2)
        # Drop the trailing </s> end-of-sequence marker.
        generated_text = decoded_output[0].replace('</s>', '').strip()
return generated_text
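

if __name__ == "__main__":
    # Minimal local smoke test, a sketch only: it assumes the weights for the
    # model id below (the commented-out default in __init__) are available and
    # that this host has Inferentia2 NeuronCores attached.
    model = MistralModel("mistralai/Mistral-7B-Instruct-v0.2")
    output = model.generate(
        "Explain what AWS Inferentia2 is in one sentence.",
        parameters={"max_new_tokens": 64, "temperature": 0.1},
    )
    print(output)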