File size: 3,451 Bytes
bb87f7a
a8771b3
1da696e
 
a8771b3
 
 
 
 
 
 
 
 
 
1da696e
 
 
 
 
 
 
 
a765a0a
a8771b3
1da696e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8771b3
1da696e
 
 
 
 
 
 
a765a0a
1da696e
 
 
 
 
 
 
 
 
 
a8771b3
a765a0a
9c1566d
a765a0a
 
 
 
 
1da696e
 
9c1566d
 
 
a765a0a
 
 
9c1566d
a765a0a
 
1da696e
a765a0a
 
1da696e
 
 
 
 
 
 
 
 
a765a0a
9c1566d
a765a0a
1da696e
 
 
 
a765a0a
 
1da696e
 
 
 
 
 
bb87f7a
 
a765a0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from threading import Thread

import gradio as gr
import torch
# Unsloth asks to be imported before transformers so its patches take effect.
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer

# Load the fine-tuned model and tokenizer from the current directory.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=".",  # Path to your fine-tuned model
    max_seq_length=8192,
    # Unsloth expects a torch dtype (or None for auto-detection); the string
    # 'bf16' is not an accepted value and fails its dtype check.
    dtype=torch.bfloat16,
    load_in_4bit=False,
)
FastLanguageModel.for_inference(model)  # Enable optimized inference

def get_streaming_generator(model, tokenizer, history, max_new_tokens=8192):
    """Start generation in a background thread and return its text streamer.

    Args:
        model: Model whose ``generate`` method is run in a worker thread.
        tokenizer: Tokenizer used to render the chat template, encode the
            prompt and decode the generated tokens.
        history: Sequence of ``[user_message, assistant_message]`` pairs. A
            missing or falsy assistant message (e.g. the turn currently being
            answered) is left out of the prompt.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        A ``TextIteratorStreamer``. Iterating it yields decoded text chunks as
        the worker thread produces them and stops when generation ends.
    """
    # Convert Gradio's pair-style history into chat-template messages.
    messages = []
    for exchange in history:
        messages.append({"role": "user", "content": exchange[0]})
        if len(exchange) > 1 and exchange[1]:
            messages.append({"role": "assistant", "content": exchange[1]})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already contains the model's special tokens
    # (e.g. BOS), so don't let the tokenizer add them a second time. Padding is
    # omitted: there is only one sequence, and requesting it fails on
    # tokenizers that define no pad token.
    inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        add_special_tokens=False,
        return_attention_mask=True,
    ).to(model.device)

    # skip_prompt: emit only newly generated text; skip_special_tokens keeps
    # end-of-turn / EOS markers out of the displayed reply.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )

    # generate() blocks until done, so run it in a thread while the caller
    # consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    return streamer

def predict(message, history):
    """Append ``message`` to the chat and stream the model's reply into it.

    Args:
        message: Text the user just submitted.
        history: Current chatbot value as ``[user, assistant]`` pairs, or
            ``None``/empty for a fresh conversation.

    Yields:
        The updated history. The last pair's assistant slot holds the reply
        generated so far. For an empty or whitespace-only message, the history
        is yielded once, unchanged.
    """
    # Work on a copy so the caller's list and its inner pairs are not mutated.
    history = [list(exchange) for exchange in history or []]

    # Nothing to send: keep the conversation as-is rather than prompting the
    # model with an empty user turn.
    if not message or not message.strip():
        yield history
        return

    history.append([message, ""])
    # Show the user's message right away, before the first token arrives.
    yield history

    streamer = get_streaming_generator(model, tokenizer, history)

    full_response = ""
    for text_chunk in streamer:
        full_response += text_chunk
        # Update the last message with the current full response.
        history[-1][1] = full_response
        yield history

def clear_chat():
    """Return fresh values for the chatbot (no messages) and the input box."""
    empty_transcript = []
    blank_input = ""
    return empty_transcript, blank_input

# Create the Gradio interface with Markdown support
with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
    # Inline ($...$) and display ($$...$$) math delimiters for the transcript.
    math_delimiters = [
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ]
    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=600,
        bubble_full_width=False,
        render_markdown=True,
        latex_delimiters=math_delimiters,
    )
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here... (Markdown supported)",
        lines=2,
    )
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")

    # Pressing Enter in the textbox and clicking Submit both stream through
    # `predict`. Only the Enter path gets an explicit API name. Once streaming
    # finishes, the textbox is emptied.
    for trigger, endpoint in ((msg.submit, "predict"), (submit.click, None)):
        trigger(predict, [msg, chatbot], [chatbot], api_name=endpoint).then(
            lambda: "", None, [msg]
        )

    # Reset transcript and textbox immediately, bypassing the queue.
    clear.click(clear_chat, None, [chatbot, msg], queue=False)

if __name__ == "__main__":
    # `predict` is a generator (streaming). Gradio 3.x only streams generator
    # handlers when the queue is enabled; on newer versions queuing is already
    # on by default, so calling queue() explicitly is harmless.
    iface.queue().launch()