# generate.py
import threading
from typing import Iterator, Union

import torch
from transformers import TextIteratorStreamer

from load_model import get_model


def generate_response(
    user_prompt: str,
    system_prompt: str = "You are a helpful AI assistant.",
    max_tokens: int = 32768,
    stream: bool = False,
) -> Union[str, Iterator[str]]:
    """Generate a response using the already-loaded model.

    Returns the full decoded string when stream=False, or an iterator
    of text chunks when stream=True.
    """
    model, tokenizer = get_model()  # fast: the model is cached, nothing loads here

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # A single prompt is never padded, so every position is attended to.
    # (Comparing against tokenizer.pad_token_id breaks when it is None.)
    attention_mask = torch.ones_like(input_ids)

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,  # disables the KV cache: lower memory, much slower decoding
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
    )

    if stream:
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs["streamer"] = streamer

        # generate() blocks, so run it in a worker thread and yield chunks
        # as the streamer produces them. The inner generator keeps the
        # non-streaming path an ordinary function: a `yield` directly in
        # generate_response would turn the whole function into a generator,
        # and the `return` below would never reach the caller.
        def _iter_chunks() -> Iterator[str]:
            thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
            thread.start()
            for text in streamer:
                yield text
            thread.join()

        return _iter_chunks()

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()
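

# A minimal usage sketch for both modes (an assumption, not part of the
# original file; presumes load_model.get_model() loads lazily on first call
# so `python generate.py` works standalone):
if __name__ == "__main__":
    # Non-streaming: the full reply comes back as one string.
    print(generate_response("Explain KV caching in one sentence."))

    # Streaming: print chunks as the worker thread produces them.
    for chunk in generate_response("Count to five.", stream=True):
        print(chunk, end="", flush=True)
    print()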