import torch
import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
from threading import Thread

# Load the fine-tuned model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=".",  # Path to the fine-tuned model
    max_seq_length=8192,
    dtype=torch.bfloat16,  # from_pretrained expects a torch dtype (or None for auto)
    load_in_4bit=False,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference
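# Note: bfloat16 requires an Ampere-or-newer GPU. As a sketch (assuming a CUDA
# build of torch), the dtype could instead be picked dynamically:
#   dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16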

def get_streaming_generator(model, tokenizer, history, max_new_tokens=8192):
    """Start generation in a background thread and return a streaming iterator."""
    # Convert the Gradio [user, assistant] pairs into the chat-message
    # format expected by the tokenizer's chat template
    formatted_history = []
    for exchange in history:
        formatted_history.append({"role": "user", "content": exchange[0]})
        if len(exchange) > 1 and exchange[1]:
            formatted_history.append({"role": "assistant", "content": exchange[1]})

    inputs = tokenizer(
        [
            tokenizer.apply_chat_template(
                formatted_history,
                tokenize=False,
                add_generation_prompt=True,
            ),
        ],
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to("cuda")

    # Create the streamer; skip the prompt and special tokens so only newly
    # generated text is yielded
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a separate thread so the streamer can be consumed
    # while tokens are still being produced
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    return streamer
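
# A minimal smoke test (a sketch, not called anywhere): consume the streamer
# directly, outside Gradio. The `_demo_stream` name and prompt are illustrative
# only; the history uses the same [user, assistant] pair format that predict()
# builds below.
def _demo_stream():
    demo_history = [["Hello! What can you do?", ""]]
    for chunk in get_streaming_generator(model, tokenizer, demo_history, max_new_tokens=64):
        print(chunk, end="", flush=True)
    print()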

def predict(message, history):
    # Add the user message to the history in the [user, assistant] pair
    # format the Chatbot component expects
    history = history or []
    history.append([message, ""])

    # Get the streamer for the properly formatted history
    streamer = get_streaming_generator(model, tokenizer, history)

    # Stream the response, updating the last assistant message as each
    # chunk arrives
    full_response = ""
    for text_chunk in streamer:
        full_response += text_chunk
        history[-1][1] = full_response
        yield history


def clear_chat():
    return [], ""

# Create the Gradio interface with Markdown and LaTeX support
with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=600,
        bubble_full_width=False,
        render_markdown=True,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here... (Markdown supported)",
        lines=2,
    )
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    # Wire up the chat interface with streaming
    msg.submit(
        predict,
        [msg, chatbot],
        [chatbot],
        api_name="predict",
    ).then(
        lambda: "", None, [msg]  # Clear input after submission
    )
    submit.click(
        predict,
        [msg, chatbot],
        [chatbot],
    ).then(
        lambda: "", None, [msg]  # Clear input after submission
    )
    clear.click(
        clear_chat,
        None,
        [chatbot, msg],
        queue=False,
    )
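
# Note: streaming generator outputs rely on Gradio's queue; Gradio 4 enables
# it by default, while on 3.x you would call `iface.queue()` before launch().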
if __name__ == "__main__":
    iface.launch()