# Tinyllama_Chat / app.py
# Source: Tomoniai's Hugging Face Space (commit bbe36b7, verified)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import StoppingCriteria, StoppingCriteriaList, MaxLengthCriteria
from threading import Thread
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name, low_cpu_mem_usage=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device=device)
def format_prompt(message, history):
prompt = "<|system|>\nYou are TinyLlama, a friendly AI assistant.</s>"
for user_prompt, bot_response in history:
prompt += f"\n<|user|>\n{user_prompt}</s>"
prompt += f"\n<|assistant|>\n{bot_response}</s>"
prompt += f"\n<|user|>\n{message}</s>\n<|assistant|>\n"
return prompt
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = [2]
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
def generate(prompt, history):
formatted_prompt = format_prompt(prompt, history)
input_ids = tokenizer([formatted_prompt], return_tensors="pt").to(device)
stop_criteria = StoppingCriteriaList([StopOnTokens()])
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=512, do_sample=True, top_p=0.95, top_k=50,
temperature=0.5, num_beams=1, stopping_criteria=stop_criteria)
thread = Thread(target=model.generate, kwargs=generation_kwargs )
thread.start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
if '</s>' in generated_text:
break
yield generated_text
mychatbot = gr.Chatbot(
avatar_images=["user.png", "botl.png"], bubble_full_width=False, show_label=False, show_copy_button=True, likeable=True,)
demo = gr.ChatInterface(fn=generate,
chatbot=mychatbot,
title=" Tomoniai's Tinyllama Chat ",
description=" Tiny but an awesome model. The response may be slow for cpu environments. Try with gpu for faster answers.",
retry_btn=None,
undo_btn=None
)
demo.queue().launch(show_api=False)