Spaces:
Build error
Build error
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    MaxLengthCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# --- Model setup -----------------------------------------------------------
# TinyLlama 1.1B chat checkpoint — small enough to run on a CPU Space.
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    low_cpu_mem_usage=True,
)

# Use the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
def format_prompt(message, history):
    """Render the TinyLlama chat template for *message* plus prior turns.

    *history* is a sequence of (user_text, assistant_text) pairs. The
    returned string ends with an open ``<|assistant|>`` tag so the model
    continues from there.
    """
    parts = ["<|system|>\nYou are TinyLlama, a friendly AI assistant.</s>"]
    for user_turn, assistant_turn in history:
        parts.append(f"\n<|user|>\n{user_turn}</s>")
        parts.append(f"\n<|assistant|>\n{assistant_turn}</s>")
    parts.append(f"\n<|user|>\n{message}</s>\n<|assistant|>\n")
    return "".join(parts)
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recent token is a stop id."""

    # Token id 2 — presumably the </s> EOS id for the TinyLlama tokenizer;
    # confirm against tokenizer.eos_token_id.
    STOP_IDS = (2,)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in self.STOP_IDS)
def generate(prompt, history):
    """Stream a chat completion for *prompt* given the (user, bot) *history*.

    Runs ``model.generate`` on a background thread and yields the
    accumulated text after every new token, so the Gradio ChatInterface
    renders a live-updating response.
    """
    formatted_prompt = format_prompt(prompt, history)
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(device)
    stop_criteria = StoppingCriteriaList([StopOnTokens()])
    # skip_special_tokens=True means the streamed text never contains a
    # literal "</s>" marker, so the old substring scan for it was dead
    # code; generation terminates via the stopping criteria instead.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.5,
        num_beams=1,
        stopping_criteria=stop_criteria,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    # Make sure the worker has fully finished before the generator closes.
    thread.join()
# Chatbot widget with custom avatars and a copy button on each response.
# NOTE(review): bubble_full_width / likeable on gr.Chatbot and
# retry_btn / undo_btn on gr.ChatInterface were removed in Gradio 5;
# if the Space shows a build error, pinning gradio<5 (or dropping these
# kwargs) is the likely fix — confirm the pinned gradio version.
mychatbot = gr.Chatbot(
    avatar_images=["user.png", "botl.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

demo = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    title=" Tomoniai's Tinyllama Chat ",
    description=" Tiny but an awesome model. The response may be slow for cpu environments. Try with gpu for faster answers.",
    retry_btn=None,
    undo_btn=None,
)

demo.queue().launch(show_api=False)