ChatTime / model /model.py
a12354's picture
Add files using upload-large-folder tool
8d2b389 verified
Raw
History Blame Contribute Delete
4.44 kB
import re
from statistics import mode
import numpy as np
import torch
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
pipeline,
)
from utils.prompt import getPrompt
from utils.tools import Discretizer, Serializer
class ChatTime:
def __init__(self, model_path, hist_len=None, pred_len=None,
max_pred_len=16, num_samples=8, top_k=100, top_p=1.0, temperature=1.0):
self.model_path = model_path
self.hist_len = hist_len
self.pred_len = pred_len
self.max_pred_len = max_pred_len
self.num_samples = num_samples
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.discretizer = Discretizer()
self.serializer = Serializer()
self.model = LlamaForCausalLM.from_pretrained(
self.model_path,
low_cpu_mem_usage=True,
return_dict=True,
torch_dtype=torch.float16,
device_map="auto",
)
self.tokenizer = LlamaTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "right"
self.eos_token_id = self.tokenizer.eos_token_id
def predict(self, hist_data, context=None):
if self.hist_len is None or self.pred_len is None:
raise ValueError("hist_len and pred_len must be specified before prediction")
series = hist_data
prediction_list = []
remaining = self.pred_len
while remaining > 0:
dispersed_series = self.discretizer.discretize(series)
serialized_series = self.serializer.serialize(dispersed_series)
serialized_series = getPrompt(flag="prediction", context=context, input=serialized_series)
pipe = pipeline(
task="text-generation",
model=self.model,
tokenizer=self.tokenizer,
min_new_tokens=2 * min(remaining, self.max_pred_len) + 8,
max_new_tokens=2 * min(remaining, self.max_pred_len) + 8,
do_sample=True,
num_return_sequences=self.num_samples,
top_k=self.top_k,
top_p=self.top_p,
temperature=self.temperature,
eos_token_id=self.eos_token_id,
)
samples = pipe(serialized_series)
pred_list = []
for sample in samples:
serialized_prediction = sample["generated_text"].split("### Response:\n")[1]
dispersed_prediction = self.serializer.inverse_serialize(serialized_prediction)
pred = self.discretizer.inverse_discretize(dispersed_prediction)
if len(pred) < min(remaining, self.max_pred_len):
pred = np.concatenate([pred, np.full(min(remaining, self.max_pred_len) - len(pred), np.NaN)])
pred_list.append(pred[:min(remaining, self.max_pred_len)])
prediction = np.nanmedian(pred_list, axis=0)
prediction_list.append(prediction)
remaining -= prediction.shape[-1]
if remaining <= 0:
break
series = np.concatenate([series, prediction], axis=-1)
prediction = np.concatenate(prediction_list, axis=-1)
return prediction
def analyze(self, question, series):
dispersed_series = self.discretizer.discretize(series)
serialized_series = self.serializer.serialize(dispersed_series)
serialized_series = getPrompt(flag="analysis", instruction=question, input=serialized_series)
pipe = pipeline(
task="text-generation",
model=self.model,
tokenizer=self.tokenizer,
max_new_tokens=self.max_pred_len,
do_sample=True,
num_return_sequences=self.num_samples,
top_k=self.top_k,
top_p=self.top_p,
temperature=self.temperature,
eos_token_id=self.eos_token_id,
)
samples = pipe(serialized_series)
response_list = []
for sample in samples:
response = sample["generated_text"].split("### Response:\n")[1].split('.')[0] + "."
response = re.findall(r"\([abc]\)", response)[0]
response_list.append(response)
response = mode(response_list)
return response