import torch
import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
from threading import Thread

# Load your fine-tuned model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=".",  # Path to your fine-tuned model
    max_seq_length=8192,
    dtype=torch.bfloat16,  # pass a real torch dtype rather than the string 'bf16'
    load_in_4bit=False,
)
FastLanguageModel.for_inference(model)  # Enable Unsloth's optimized inference path
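# NOTE (assumption): this script expects a CUDA GPU, since inputs are moved to
# "cuda" below and bf16 math needs Ampere-or-newer hardware. On older cards,
# switch to dtype=torch.float16 or set load_in_4bit=True.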
def get_streaming_generator(model, tokenizer, history, max_new_tokens=8192):
    """Kick off generation in a background thread and return a streamer that
    yields decoded text chunks as they are produced."""
    # Convert Gradio's [user, assistant] pairs into the chat-message dicts
    # expected by apply_chat_template
    formatted_history = []
    for exchange in history:
        formatted_history.append({"role": "user", "content": exchange[0]})
        if len(exchange) > 1 and exchange[1]:
            formatted_history.append({"role": "assistant", "content": exchange[1]})
    inputs = tokenizer(
        [
            tokenizer.apply_chat_template(
                formatted_history,
                tokenize=False,
                add_generation_prompt=True,
            ),
        ],
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to("cuda")
    # Create the streamer; skip_prompt drops the echoed input, and
    # skip_special_tokens keeps EOS/control tokens out of the chat window
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Run generation in a separate thread so the caller can iterate over the
    # streamer while tokens are still being produced
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
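# The streamer returned above is consumed in predict() below: model.generate()
# runs on the worker thread while the Gradio callback iterates over decoded
# chunks, so the UI updates after every chunk instead of waiting for the full
# completion.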
def predict(message, history):
    # Add the user message to history as a new [user, assistant] pair, with the
    # assistant slot empty until tokens start arriving
    history = history or []
    history.append([message, ""])
    # Get the streamer with the properly formatted history
    streamer = get_streaming_generator(model, tokenizer, history)
    # Stream the response: yielding history on each chunk makes this a
    # generator, which Gradio treats as a streaming callback
    full_response = ""
    for text_chunk in streamer:
        full_response += text_chunk
        # Update the last message with the current full response
        history[-1][1] = full_response
        yield history
def clear_chat():
    return [], ""
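# The UI below uses Gradio's tuple-style chat history ([user, assistant]
# pairs), which matches what predict() yields. This is the classic Chatbot
# format; newer Gradio releases prefer type="messages".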
# Create the Gradio interface with Markdown support
with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=600,
        bubble_full_width=False,
        render_markdown=True,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here... (Markdown supported)",
        lines=2,
    )
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")
    # Set up the chat interface with streaming
    msg.submit(
        predict,
        [msg, chatbot],
        [chatbot],
        api_name="predict",
    ).then(
        lambda: "", None, [msg]  # Clear the input box after submission
    )
    submit.click(
        predict,
        [msg, chatbot],
        [chatbot],
    ).then(
        lambda: "", None, [msg]  # Clear the input box after submission
    )
    clear.click(
        clear_chat,
        None,
        [chatbot, msg],
        queue=False,
    )
if __name__ == "__main__":
    iface.launch()
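# Usage sketch (assumptions: the fine-tuned checkpoint lives in the current
# directory and a CUDA GPU is available):
#   python app.py
# Pass share=True to iface.launch() to get a temporary public URL.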