Spaces:
Build error
Build error
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
| import gradio as gr | |
| import torch | |
# Upper bound on prompt length (in tokens); `generate` keeps only the most
# recent MAX_INPUT_TOKEN_LENGTH tokens when a conversation grows longer.
MAX_INPUT_TOKEN_LENGTH = 4096

model_id = 'HuggingFaceH4/zephyr-7b-beta'

# Tokenizer for the checkpoint. Its chat template renders the conversation.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): this flag appears to be read only by some chat templates
# (e.g. Llama-2's); it may be a no-op for Zephyr — confirm before relying on it.
tokenizer.use_default_system_prompt = False

# Full-precision (float32) weights; device_map='auto' delegates device
# placement to transformers (which requires the `accelerate` package).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.float32,
)
def generate(
    input,
    chat_history=None,
    system_prompt=False,
    max_new_tokens=512,
    temperature=0.5,
    top_p=0.95,
    top_k=50,
    repetition_penalty=1.2,
):
    """Stream the model's reply to ``input`` given the prior chat turns.

    Intended as the ``fn`` of ``gr.ChatInterface``. Each yielded value is the
    complete reply generated so far, not just the newest fragment, so the UI
    can simply re-render it.

    Args:
        input: The latest user message.
        chat_history: Prior turns as an iterable of ``(user, assistant)``
            pairs. ``None`` (the default) means there are no prior turns.
        system_prompt: Optional system message. Any falsy value omits it.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        str: All text streamed so far, concatenated.
    """
    # Avoid a shared mutable default argument.
    if chat_history is None:
        chat_history = []

    conversation = []
    if system_prompt:
        conversation.append({'role': 'system', 'content': system_prompt})
    for user, assistant in chat_history:
        # list.extend takes a single iterable, so pass both turns as one list.
        conversation.extend([
            {'role': 'user', 'content': user},
            {'role': 'assistant', 'content': assistant},
        ])
    conversation.append({'role': 'user', 'content': input})

    # add_generation_prompt=True appends the template's assistant-turn header,
    # so the model is cued to answer as the assistant rather than continue
    # the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors='pt'
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens. The oldest context, including any
        # system prompt, is dropped first.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            'Trimmed input from conversation as it was longer than '
            f'{MAX_INPUT_TOKEN_LENGTH} tokens.'
        )
    input_ids = input_ids.to(model.device)

    # model.generate blocks until it finishes, so run it on a worker thread
    # and consume the decoded text here as the streamer produces it.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)
    # The stream is exhausted, so generation has finished; reap the thread.
    worker.join()
# Chat UI backed by the streaming `generate` function. The examples are
# clickable starter prompts shown to the user.
chat_interface = gr.ChatInterface(
    fn=generate,
    examples=['What is GPT?', 'What is Life?', 'Who is Alan Turing'],
)

# Enable request queuing (capped at 20 waiting events), then start the server.
chat_interface.queue(max_size=20)
chat_interface.launch()