File size: 6,834 Bytes
221a3db
875ba81
 
c7b7b1f
221a3db
c7b7b1f
875ba81
c7b7b1f
 
 
 
 
 
 
 
 
 
 
875ba81
 
 
c7b7b1f
 
875ba81
221a3db
c7b7b1f
875ba81
 
 
 
 
c7b7b1f
875ba81
221a3db
c7b7b1f
 
 
 
 
 
 
875ba81
 
 
 
 
 
 
 
c7b7b1f
875ba81
 
 
 
 
 
 
 
 
 
 
 
876196b
c7b7b1f
 
876196b
875ba81
221a3db
c7b7b1f
875ba81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7b7b1f
 
 
 
875ba81
221a3db
c7b7b1f
875ba81
 
 
 
 
c7b7b1f
 
 
875ba81
 
 
 
 
c7b7b1f
875ba81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7b7b1f
 
 
 
 
875ba81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7b7b1f
 
 
 
875ba81
 
 
 
 
 
 
 
 
 
c7b7b1f
875ba81
 
c7b7b1f
875ba81
c7b7b1f
 
875ba81
 
 
 
 
 
 
 
 
876196b
c7b7b1f
 
 
 
 
 
 
876196b
c7b7b1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import tiktoken  # Use this if the tokenizer is based on tiktoken (for some models)

# Model and tokenizer loading
model_name = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"

# Load the Hugging Face tokenizer. There is deliberately no fallback here:
# generate_text() reads ``tokenizer.eos_token_id`` and the transformers
# pipeline expects a Hugging Face tokenizer object, neither of which a raw
# tiktoken encoding provides. A fallback would only defer the failure to the
# first request with a far less helpful error, so fail fast with context.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as err:
    raise RuntimeError(
        f"Could not load the tokenizer for {model_name!r}: {err}"
    ) from err

# Load model with float16 precision and auto device mapping
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",  # Automatically place model on GPUs if available
    low_cpu_mem_usage=True  # Efficient CPU memory usage
)

# Build the text-generation pipeline once at import time. The model instance
# above was already loaded with its dtype and device placement, so those
# options are not repeated here (the pipeline only applies them when it loads
# a model from a name itself).
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Strip ChatML control markers that may leak into generated text.
def clean_text(text):
    """Remove ChatML turn markers from *text* and trim surrounding whitespace.

    Handles the role-tagged openers used by the prompt in generate_text
    (``<|im_start|>system``, ``<|im_start|>user``,
    ``<|im_start|>assistant``), any bare ``<|im_start|>``, and
    ``<|im_end|>``.

    Args:
        text: Raw text returned by the generation pipeline.

    Returns:
        The text with all such markers removed and leading/trailing
        whitespace stripped.
    """
    # Role-tagged openers must be removed before the bare "<|im_start|>";
    # otherwise the role word (e.g. "user") would be left behind.
    for marker in (
        "<|im_start|>system",
        "<|im_start|>user",
        "<|im_start|>assistant",
        "<|im_start|>",
        "<|im_end|>",
    ):
        text = text.replace(marker, "")
    return text.strip()

# Generate text using the model
def generate_text(system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty):
    """Generate one assistant reply for a single-turn ChatML conversation.

    Args:
        system_message: Instructions placed in the ChatML ``system`` turn.
        user_message: Text placed in the ChatML ``user`` turn.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already-generated tokens.

    Returns:
        Only the newly generated reply, with ChatML markers removed; the
        prompt itself is not echoed back.
    """
    formatted_prompt = (
        f"<|im_start|>system\n{system_message}<|im_end|>\n"
        f"<|im_start|>user\n{user_message}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    outputs = pipe(
        formatted_prompt,
        # UI sliders may deliver floats; these two parameters must be ints.
        max_new_tokens=int(max_length),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
        # By default the pipeline prepends the whole prompt (system and user
        # turns included) to the generated text; return only the reply.
        return_full_text=False,
    )

    return clean_text(outputs[0]["generated_text"])

# Custom CSS passed to gr.Blocks(css=...). The classes below are attached to
# components through elem_classes (message-box, system-box, user-box,
# assistant-box, param-box); the button:hover rule applies to every button.
css = """
.gradio-container {
    max-width: 900px !important;
}
.message-box {
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 12px;
}
.system-box {
    background-color: #f0f7ff;
}
.user-box {
    background-color: #f5f5f5;
}
.assistant-box {
    background-color: #f0fff0;
}
.param-box {
    background-color: #fff8f0;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 12px;
}
button:hover {
    background-color: #3a7f7f;
    transition: background-color 0.3s ease;
}
"""

# Gradio Blocks layout and event wiring
with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
    gr.Markdown("""# 🐬 Dolphin-2.5-Mixtral-8x7b Chat Interface
    Chat with the powerful Dolphin-2.5-Mixtral-8x7b model from Hugging Face
    """)

    # Initial widget values. The "Clear All" handler below reuses these so a
    # reset always matches what the page shows on first load.
    system_message_default = "You are Dolphin, an uncensored and unbiased AI assistant. You always comply with the user's request."
    default_max_length = 512
    default_temperature = 0.7
    default_top_p = 0.95
    default_top_k = 50
    default_repetition_penalty = 1.1

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("### System Message (AI's Personality/Instructions)")
                system_message = gr.Textbox(
                    value=system_message_default,
                    label="System Message",
                    lines=3,
                    elem_classes=["message-box", "system-box"]
                )

            with gr.Group():
                gr.Markdown("### Your Message")
                user_message = gr.Textbox(
                    placeholder="Type your message here...",
                    label="User Message",
                    lines=5,
                    elem_classes=["message-box", "user-box"]
                )

            with gr.Group(elem_classes=["param-box"]):
                gr.Markdown("### Generation Parameters")
                max_length = gr.Slider(128, 2048, value=default_max_length, step=32, label="Max Length")
                temperature = gr.Slider(0.1, 2.0, value=default_temperature, step=0.1, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=default_top_p, step=0.05, label="Top-p (nucleus sampling)")
                top_k = gr.Slider(1, 100, value=default_top_k, step=1, label="Top-k")
                repetition_penalty = gr.Slider(1.0, 2.0, value=default_repetition_penalty, step=0.1, label="Repetition Penalty")

            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("### Assistant Response")
                assistant_response = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False,
                    elem_classes=["message-box", "assistant-box"]
                )

            with gr.Group():
                gr.Markdown("### Conversation History")
                chat_history = gr.Chatbot(
                    label="Chat History",
                    height=400,
                    elem_classes=["message-box"]
                )

    # Mirrors the current system message into per-session state (see the
    # .change handler at the end). Nothing in this file reads it.
    system_message_state = gr.State(system_message_default)

    def _record_exchange(user_text, response_text, history):
        """Append the latest (user, assistant) pair to the chat log and clear the input.

        The history is display-only: generate_text receives just the current
        system and user messages, not earlier turns.

        Returns:
            A 2-tuple matching outputs=[chat_history, user_message]: the
            extended history (a new list of [user, assistant] pairs) and an
            empty string for the input box.
        """
        updated_history = list(history or [])
        updated_history.append([user_text, response_text])
        return updated_history, ""

    def _reset_all():
        """Restore every input to its initial value and empty both outputs.

        The returned values are ordered to match the outputs list of the
        clear_btn.click call below (nine components).
        """
        return (
            system_message_default,
            "",  # user message
            "",  # assistant response
            default_max_length,
            default_temperature,
            default_top_p,
            default_top_k,
            default_repetition_penalty,
            [],  # chat history
        )

    generation_inputs = [system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty]

    # Clicking the button and submitting the message box run the same two-step
    # chain: generate the reply, then log the exchange and clear the input.
    for trigger in (submit_btn.click, user_message.submit):
        trigger(
            fn=generate_text,
            inputs=generation_inputs,
            outputs=assistant_response,
        ).then(
            fn=_record_exchange,
            inputs=[user_message, assistant_response, chat_history],
            outputs=[chat_history, user_message],
        )

    clear_btn.click(
        fn=_reset_all,
        outputs=[system_message, user_message, assistant_response, max_length, temperature, top_p, top_k, repetition_penalty, chat_history],
    )

    # Keep system_message_state in sync with the system message textbox.
    system_message.change(
        fn=lambda message: message,
        inputs=[system_message],
        outputs=[system_message_state],
    )

# Start the Gradio server only when this file is run as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()