import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
from threading import Thread
from torch import bfloat16

# Load your fine-tuned model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=".",  # Path to your fine-tuned model
    max_seq_length=8192,  # Prompt and generated tokens together must fit in this window
    dtype=bfloat16,
    load_in_4bit=False,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference
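
# Memory note: setting load_in_4bit=True in from_pretrained above trades some
# output quality for a much smaller footprint (4-bit quantized weights take
# roughly a quarter of the bfloat16 size).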

def get_streaming_generator(model, tokenizer, history, max_new_tokens=8192):
    """Start generation in a background thread and return a streamer that yields text chunks."""
    # Convert history to the format expected by tokenizer
    formatted_history = []
    for exchange in history:
        formatted_history.append({"role": "user", "content": exchange[0]})
        if len(exchange) > 1 and exchange[1]:
            formatted_history.append({"role": "assistant", "content": exchange[1]})
    
    inputs = tokenizer(
        [
            tokenizer.apply_chat_template(formatted_history, 
                                        tokenize=False, 
                                        add_generation_prompt=True),
        ], 
        return_tensors="pt",
        padding=True,
        return_attention_mask=True
    ).to("cuda")
    
    # Create the streamer; skip the prompt and special tokens so only new text reaches the UI
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Run generation in a separate thread
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens
    )
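    # Optional sampling controls (standard transformers generate kwargs) could go here, e.g.:
    # generation_kwargs.update(do_sample=True, temperature=0.7, top_p=0.9)
    # Without them, decoding follows the model's generation_config defaults.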
    
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    return streamer

def predict(message, history):
    # Add user message to history in the format Gradio expects
    history = history or []
    history.append([message, ""])
    
    # Get the streamer with properly formatted history
    streamer = get_streaming_generator(model, tokenizer, history)
    
    # Stream the response; each yield re-renders the chatbot with the partial text
    full_response = ""
    for text_chunk in streamer:
        full_response += text_chunk
        # Update the last message with the current full response
        history[-1][1] = full_response
        yield history

def clear_chat():
    return [], ""

# Create the Gradio interface with Markdown support
with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=600,
        bubble_full_width=False,
        render_markdown=True,
        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
        ],
    )
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here... (Markdown supported)",
        lines=2
    )
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")
    
    # Set up the chat interface with streaming
    msg.submit(
        predict,
        [msg, chatbot],
        [chatbot],
        api_name="predict"
    ).then(
        lambda: "", None, [msg]  # Clear input after submission
    )
    
    submit.click(
        predict,
        [msg, chatbot],
        [chatbot]
    ).then(
        lambda: "", None, [msg]  # Clear input after submission
    )
    
    clear.click(
        clear_chat,
        None, 
        [chatbot, msg], 
        queue=False
    )

if __name__ == "__main__":
    iface.launch()
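
# launch() also accepts standard Gradio options, e.g. share=True for a
# temporary public URL, or server_name="0.0.0.0" to listen on all interfaces:
# iface.launch(share=True, server_name="0.0.0.0")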