import re from statistics import mode import numpy as np import torch from transformers import ( LlamaForCausalLM, LlamaTokenizer, pipeline, ) from utils.prompt import getPrompt from utils.tools import Discretizer, Serializer class ChatTime: def __init__(self, model_path, hist_len=None, pred_len=None, max_pred_len=16, num_samples=8, top_k=100, top_p=1.0, temperature=1.0): self.model_path = model_path self.hist_len = hist_len self.pred_len = pred_len self.max_pred_len = max_pred_len self.num_samples = num_samples self.top_k = top_k self.top_p = top_p self.temperature = temperature self.discretizer = Discretizer() self.serializer = Serializer() self.model = LlamaForCausalLM.from_pretrained( self.model_path, low_cpu_mem_usage=True, return_dict=True, torch_dtype=torch.float16, device_map="auto", ) self.tokenizer = LlamaTokenizer.from_pretrained(self.model_path, trust_remote_code=True) self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.padding_side = "right" self.eos_token_id = self.tokenizer.eos_token_id def predict(self, hist_data, context=None): if self.hist_len is None or self.pred_len is None: raise ValueError("hist_len and pred_len must be specified before prediction") series = hist_data prediction_list = [] remaining = self.pred_len while remaining > 0: dispersed_series = self.discretizer.discretize(series) serialized_series = self.serializer.serialize(dispersed_series) serialized_series = getPrompt(flag="prediction", context=context, input=serialized_series) pipe = pipeline( task="text-generation", model=self.model, tokenizer=self.tokenizer, min_new_tokens=2 * min(remaining, self.max_pred_len) + 8, max_new_tokens=2 * min(remaining, self.max_pred_len) + 8, do_sample=True, num_return_sequences=self.num_samples, top_k=self.top_k, top_p=self.top_p, temperature=self.temperature, eos_token_id=self.eos_token_id, ) samples = pipe(serialized_series) pred_list = [] for sample in samples: serialized_prediction = sample["generated_text"].split("### Response:\n")[1] dispersed_prediction = self.serializer.inverse_serialize(serialized_prediction) pred = self.discretizer.inverse_discretize(dispersed_prediction) if len(pred) < min(remaining, self.max_pred_len): pred = np.concatenate([pred, np.full(min(remaining, self.max_pred_len) - len(pred), np.NaN)]) pred_list.append(pred[:min(remaining, self.max_pred_len)]) prediction = np.nanmedian(pred_list, axis=0) prediction_list.append(prediction) remaining -= prediction.shape[-1] if remaining <= 0: break series = np.concatenate([series, prediction], axis=-1) prediction = np.concatenate(prediction_list, axis=-1) return prediction def analyze(self, question, series): dispersed_series = self.discretizer.discretize(series) serialized_series = self.serializer.serialize(dispersed_series) serialized_series = getPrompt(flag="analysis", instruction=question, input=serialized_series) pipe = pipeline( task="text-generation", model=self.model, tokenizer=self.tokenizer, max_new_tokens=self.max_pred_len, do_sample=True, num_return_sequences=self.num_samples, top_k=self.top_k, top_p=self.top_p, temperature=self.temperature, eos_token_id=self.eos_token_id, ) samples = pipe(serialized_series) response_list = [] for sample in samples: response = sample["generated_text"].split("### Response:\n")[1].split('.')[0] + "." response = re.findall(r"\([abc]\)", response)[0] response_list.append(response) response = mode(response_list) return response