# Source: base2/generate.py — uploaded by Ashok75
# ("Upload 28 files", commit 428ef01, verified)
# generate.py
import threading
from typing import Iterator, Union

import torch
from transformers import TextIteratorStreamer

from load_model import get_model
def generate_response(
    user_prompt: str,
    system_prompt: str = "You are a helpful AI assistant.",
    max_tokens: int = 32768,
    stream: bool = False,
) -> Union[str, Iterator[str]]:
    """Generate a chat reply using the model returned by ``get_model()``.

    The prompt is built from a system message and a user message via the
    tokenizer's chat template, then sampled with temperature 0.8,
    top-k 50 and top-p 0.95.

    Args:
        user_prompt: Content of the user turn.
        system_prompt: Content of the system turn.
        max_tokens: Maximum number of new tokens to generate.
        stream: If True, return an iterator that yields decoded text chunks
            as they are produced. If False, block until generation finishes
            and return the full reply.

    Returns:
        The stripped reply text when ``stream`` is False; otherwise an
        iterator of text chunks (special tokens and the prompt excluded).
    """
    # Presumably get_model() returns an already-loaded, cached
    # (model, tokenizer) pair — see load_model; nothing is loaded here.
    model, tokenizer = get_model()
    gen_kwargs = _build_generation_kwargs(
        model, tokenizer, user_prompt, system_prompt, max_tokens
    )

    if stream:
        # Streaming lives in a separate generator function: a `yield` in this
        # body would turn the non-streaming path into a generator as well.
        return _stream_response(model, tokenizer, gen_kwargs)

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)
    # Drop the echoed prompt tokens; decode only the newly generated part.
    prompt_len = gen_kwargs["input_ids"].shape[1]
    return tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()


def _build_generation_kwargs(model, tokenizer, user_prompt, system_prompt, max_tokens):
    """Tokenize the two-message chat prompt and assemble ``model.generate`` kwargs."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    # A single prompt is never padded, so every position is a real token.
    # Deriving the mask from pad_token_id fails when it is None and wrongly
    # masks template tokens when pad_token_id equals one of them (e.g. eos).
    attention_mask = torch.ones_like(input_ids)
    return dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.eos_token_id,
        # NOTE(review): disabling the KV cache makes long generations far
        # slower (the full prefix is recomputed every step). Kept as-is in case
        # it works around a model/cache incompatibility — confirm intent.
        use_cache=False,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
    )


def _stream_response(model, tokenizer, gen_kwargs):
    """Run ``model.generate`` on a worker thread and yield decoded text chunks."""
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    errors = []

    def _worker():
        try:
            model.generate(**gen_kwargs, streamer=streamer)
        except Exception as exc:  # thread boundary: re-raised in the consumer below
            errors.append(exc)
            # Push the end-of-stream signal; otherwise the consumer loop
            # below would block forever waiting on the streamer's queue.
            streamer.end()

    # daemon=True so an abandoned stream cannot block interpreter shutdown.
    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    for text in streamer:
        yield text
    thread.join()
    if errors:
        raise errors[0]