|
|
|
|
|
from __future__ import annotations

import threading
from collections.abc import Iterator

import torch
from transformers import TextIteratorStreamer

from load_model import get_model
|
|
|
|
|
|
def generate_response(
    user_prompt: str,
    system_prompt: str = "You are a helpful AI assistant.",
    max_tokens: int = 32768,
    stream: bool = False,
) -> str | Iterator[str]:
    """Generate a response using the ALREADY LOADED model.

    Args:
        user_prompt: The user's message.
        system_prompt: System instruction prepended to the chat history.
        max_tokens: Maximum number of new tokens to generate.
        stream: If True, return an iterator yielding text chunks as they
            are generated; otherwise return the full decoded response.

    Returns:
        The decoded response string (``stream=False``), or an iterator of
        streamed text chunks (``stream=True``).
    """
    model, tokenizer = get_model()

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Some tokenizers define no pad token (pad_token_id is None), which would
    # make the `!=` comparison below silently mask nothing correctly.
    # Fall back to eos, mirroring the pad_token_id passed to generate().
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    attention_mask = (input_ids != pad_id).long()

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
    )

    if stream:
        # BUG FIX: the original had a bare `yield` inside this function,
        # which turned the WHOLE function into a generator — with
        # stream=False callers received a generator object and the final
        # `return` string was swallowed into StopIteration. Streaming is
        # now isolated in an inner generator that is returned, so the
        # non-streaming path below returns a real str.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs["streamer"] = streamer

        def _stream() -> Iterator[str]:
            # Run generation in a background thread; the streamer feeds
            # decoded text to this foreground iterator.
            thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
            thread.start()
            try:
                for text in streamer:
                    yield text
            finally:
                # Don't leak the generation thread (the original never joined).
                thread.join()

        return _stream()

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)

    # Decode only the newly generated tokens, skipping the prompt prefix.
    return tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()
|
|
|
|