import gradio as gr
import torch
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
from threading import Thread

# Load your fine-tuned model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=".",        # Path to your fine-tuned model
    max_seq_length=8192,
    dtype=torch.bfloat16,  # Must be a torch dtype (or None for auto), not the string 'bf16'
    load_in_4bit=False,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference


def get_streaming_generator(model, tokenizer, history, max_new_tokens=8192):
    """Start generation in a background thread and return a streamer that yields text chunks."""
    # Convert the Gradio [user, assistant] pairs into the chat-template message format
    formatted_history = []
    for exchange in history:
        formatted_history.append({"role": "user", "content": exchange[0]})
        if len(exchange) > 1 and exchange[1]:
            formatted_history.append({"role": "assistant", "content": exchange[1]})

    inputs = tokenizer(
        [
            tokenizer.apply_chat_template(
                formatted_history, tokenize=False, add_generation_prompt=True
            ),
        ],
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to("cuda")

    # Create the streamer; skip_special_tokens keeps EOS/control tokens out of the chat
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a separate thread so we can consume the stream immediately.
    # Note: the prompt plus max_new_tokens must fit within max_seq_length.
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer


def predict(message, history):
    # Add the user message to history in the [user, assistant] pair format Gradio expects
    history = history or []
    history.append([message, ""])

    # Get the streamer with the properly formatted history
    streamer = get_streaming_generator(model, tokenizer, history)

    # Stream the response, updating the last message with the text generated so far
    full_response = ""
    for text_chunk in streamer:
        full_response += text_chunk
        history[-1][1] = full_response
        yield history


def clear_chat():
    return [], ""


# Create the Gradio interface with Markdown and LaTeX support
with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=600,
        bubble_full_width=False,
        render_markdown=True,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here... (Markdown supported)",
        lines=2,
    )
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    # Wire up the chat interface with streaming; clear the input box after each send
    msg.submit(
        predict, [msg, chatbot], [chatbot], api_name="predict"
    ).then(
        lambda: "", None, [msg]
    )
    submit.click(
        predict, [msg, chatbot], [chatbot]
    ).then(
        lambda: "", None, [msg]
    )
    clear.click(clear_chat, None, [chatbot, msg], queue=False)

if __name__ == "__main__":
    iface.launch()
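
# --- Optional: sampling controls (a hedged sketch, not part of the original script) ---
# `model.generate` accepts the standard Hugging Face generation arguments, so the
# `generation_kwargs` dict in get_streaming_generator() could be extended with
# sampling options. The values below are illustrative assumptions, not tuned
# recommendations:
#
#     generation_kwargs = dict(
#         input_ids=inputs["input_ids"],
#         attention_mask=inputs["attention_mask"],
#         streamer=streamer,
#         max_new_tokens=max_new_tokens,
#         do_sample=True,    # sample instead of greedy decoding
#         temperature=0.7,   # assumed value; tune for your model
#         top_p=0.9,         # nucleus-sampling cutoff
#     )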
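
# Usage note: `iface.launch()` serves the app locally by default. Gradio's
# `launch()` also takes the usual networking options if you need remote access;
# the values below are illustrative, not from the original script:
#
#     iface.launch(server_name="0.0.0.0", server_port=7860, share=True)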