import threading
from typing import Iterator, Union

import torch
from transformers import TextIteratorStreamer

from load_model import get_model


def generate_response(
    user_prompt: str,
    system_prompt: str = "You are a helpful AI assistant.",
    max_tokens: int = 32768,
    stream: bool = False,
) -> Union[str, Iterator[str]]:
    """Generate a response using the already loaded model.

    Returns the full reply as a string, or an iterator of text chunks
    when ``stream=True``.
    """
    model, tokenizer = get_model()

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # A single unpadded sequence attends to every position; comparing against
    # pad_token_id would also crash when the tokenizer defines no pad token.
    attention_mask = torch.ones_like(input_ids)

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad-token warning
        use_cache=True,  # keep the KV cache; disabling it recomputes the full prefix every step
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
    )

    if stream:
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs["streamer"] = streamer
        thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        # A bare ``yield`` here would turn generate_response itself into a
        # generator and break the non-streaming path (it could never return a
        # str), so the streaming side lives in an inner generator instead.
        def token_stream() -> Iterator[str]:
            for text in streamer:
                yield text
            thread.join()

        return token_stream()

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)

    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()
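

# Minimal usage sketch, not part of the original module: it assumes
# load_model.get_model() returns an already loaded (model, tokenizer) pair
# and that this file is executed directly.
if __name__ == "__main__":
    # Non-streaming call: returns the complete reply as one string.
    print(generate_response("Explain the KV cache in one sentence."))

    # Streaming call: returns an iterator of text chunks as they arrive.
    for chunk in generate_response("Write a haiku about GPUs.", stream=True):
        print(chunk, end="", flush=True)
    print()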