Spaces: Runtime error
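The full `app.py` of the Space — a Gradio chat UI that streams responses from an Unsloth fine-tuned model: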
```python
import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
from threading import Thread
from torch import bfloat16

# Load your fine-tuned model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=".",  # Path to your fine-tuned model
    max_seq_length=8192,
    dtype=bfloat16,
    load_in_4bit=False,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference


def get_streaming_generator(model, tokenizer, history, max_new_tokens=8192):
    """Return a streamer yielding the model's output as it is generated."""
    # Convert history to the format expected by the tokenizer's chat template
    formatted_history = []
    for exchange in history:
        formatted_history.append({"role": "user", "content": exchange[0]})
        if len(exchange) > 1 and exchange[1]:
            formatted_history.append({"role": "assistant", "content": exchange[1]})

    inputs = tokenizer(
        [
            tokenizer.apply_chat_template(
                formatted_history,
                tokenize=False,
                add_generation_prompt=True,
            ),
        ],
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to("cuda")

    # Create the streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    # Run generation in a separate thread
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer


def predict(message, history):
    # Add user message to history in the format Gradio expects
    history = history or []
    history.append([message, ""])

    # Get the streamer with properly formatted history
    streamer = get_streaming_generator(model, tokenizer, history)

    # Stream the response
    full_response = ""
    for text_chunk in streamer:
        full_response += text_chunk
        # Update the last message with the current full response
        history[-1][1] = full_response
        yield history


def clear_chat():
    return [], ""


# Create the Gradio interface with Markdown support
with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=600,
        bubble_full_width=False,
        render_markdown=True,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here... (Markdown supported)",
        lines=2,
    )
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    # Set up the chat interface with streaming
    msg.submit(
        predict,
        [msg, chatbot],
        [chatbot],
        api_name="predict",
    ).then(
        lambda: "", None, [msg]  # Clear input after submission
    )
    submit.click(
        predict,
        [msg, chatbot],
        [chatbot],
    ).then(
        lambda: "", None, [msg]  # Clear input after submission
    )
    clear.click(
        clear_chat,
        None,
        [chatbot, msg],
        queue=False,
    )

if __name__ == "__main__":
    iface.launch()
```
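Two hedged observations on the code above, offered as assumptions rather than a confirmed diagnosis of the runtime error:

`predict` is a generator, and on Gradio 3.x generator-based event handlers only run with the queue enabled (4.x turns it on by default). If the Space pins an older Gradio, the launch block would need:

```python
if __name__ == "__main__":
    iface.queue()   # required for streaming/generator handlers on Gradio 3.x
    iface.launch()
```

Separately, `TextIteratorStreamer` forwards extra keyword arguments to `tokenizer.decode`, so the EOS marker can be kept out of the chat window by constructing the streamer as:

```python
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,          # don't echo the prompt back into the output
    skip_special_tokens=True,  # drop EOS/BOS markers from the streamed text
)
```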