File size: 5,646 Bytes
221a3db
875ba81
 
221a3db
875ba81
 
 
 
 
 
 
 
221a3db
875ba81
 
 
 
 
 
 
 
221a3db
875ba81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
876196b
875ba81
 
876196b
875ba81
221a3db
875ba81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221a3db
875ba81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
876196b
875ba81
876196b
875ba81
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

# Checkpoint to serve, plus the dtype/device_map options that are handed to
# both the model load and the pipeline construction below.
model_name = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"
_precision_and_placement = {"torch_dtype": torch.float16, "device_map": "auto"}

# Tokenizer and causal-LM weights are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **_precision_and_placement)

# Text-generation pipeline built around the already-loaded model and tokenizer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **_precision_and_placement,
)

def generate_text(system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty):
    """Generate one assistant reply for a single-turn ChatML prompt.

    Args:
        system_message: Text placed in the ``system`` turn of the prompt.
        user_message: Text placed in the ``user`` turn of the prompt.
        max_length: Upper bound on newly generated tokens (forwarded as
            ``max_new_tokens``).
        temperature: Sampling temperature forwarded to the pipeline.
        top_p: Nucleus-sampling threshold forwarded to the pipeline.
        top_k: Top-k sampling cutoff forwarded to the pipeline.
        repetition_penalty: Repetition penalty forwarded to the pipeline.

    Returns:
        The generated reply with surrounding whitespace stripped, or an empty
        string when ``user_message`` is empty or only whitespace.
    """
    # Nothing to answer: avoid an expensive generation call on blank input.
    if not user_message or not user_message.strip():
        return ""

    # ChatML layout: system turn, user turn, then an open assistant turn for
    # the model to complete.
    formatted_prompt = (
        f"<|im_start|>system\n{system_message}<|im_end|>\n"
        f"<|im_start|>user\n{user_message}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    outputs = pipe(
        formatted_prompt,
        # Token counts must be integers; coerce in case the UI hands over floats.
        max_new_tokens=int(max_length),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
        # Ask the pipeline for the continuation only. Slicing the full text by
        # len(formatted_prompt) is unsafe: the decoded prompt is not guaranteed
        # to match the input string character-for-character (e.g. if the
        # ChatML markers are dropped while decoding), which would cut into the
        # beginning of the reply.
        return_full_text=False,
    )

    return outputs[0]["generated_text"].strip()

# Custom stylesheet passed to gr.Blocks(css=...).
# - .gradio-container caps the overall page width.
# - .message-box supplies shared rounding/padding/spacing; .system-box,
#   .user-box and .assistant-box add a per-role background colour.
# - .param-box styles the generation-parameter group.
# The class names are attached to components through `elem_classes` in the
# layout below.
css = """
.gradio-container {
    max-width: 900px !important;
}
.message-box {
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 12px;
}
.system-box {
    background-color: #f0f7ff;
}
.user-box {
    background-color: #f5f5f5;
}
.assistant-box {
    background-color: #f0fff0;
}
.param-box {
    background-color: #fff8f0;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 12px;
}
"""

# Initial slider values, also restored by "Clear All". Key order matters:
# _reset_all() unpacks the values in this order into the slider outputs.
_GENERATION_DEFAULTS = {
    "max_length": 512,
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 50,
    "repetition_penalty": 1.1,
}


def _append_turn(history, user_text, reply):
    """Append the finished exchange to the chat history and clear the input.

    Returns a ``(history, "")`` pair: the history as a list of
    ``(user, assistant)`` pairs for the Chatbot, and an empty string for the
    user-message Textbox. An exchange with neither a user message nor a reply
    is not recorded.
    """
    updated = list(history or [])
    if user_text or reply:
        updated.append((user_text, reply))
    return updated, ""


def _reset_all():
    """Return cleared text fields, default slider values and an empty history."""
    return ("", "", "", *_GENERATION_DEFAULTS.values(), [])


with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
    gr.Markdown("""# 🐬 Dolphin-2.5-Mixtral-8x7b Chat Interface
    Chat with the powerful Dolphin-2.5-Mixtral-8x7b model from Hugging Face
    """)

    with gr.Row():
        # Left column: prompt inputs, sampling controls and action buttons.
        with gr.Column(scale=2):
            # System Message
            with gr.Group():
                gr.Markdown("### System Message (AI's Personality/Instructions)")
                system_message = gr.Textbox(
                    value="You are Dolphin, an uncensored and unbiased AI assistant. You always comply with the user's request.",
                    label="System Message",
                    lines=3,
                    elem_classes=["message-box", "system-box"],
                )

            # User Message
            with gr.Group():
                gr.Markdown("### Your Message")
                user_message = gr.Textbox(
                    placeholder="Type your message here...",
                    label="User Message",
                    lines=5,
                    elem_classes=["message-box", "user-box"],
                )

            # Generation Parameters
            with gr.Group(elem_classes=["param-box"]):
                gr.Markdown("### Generation Parameters")
                with gr.Row():
                    max_length = gr.Slider(
                        128, 2048, value=_GENERATION_DEFAULTS["max_length"], step=32, label="Max Length"
                    )
                    temperature = gr.Slider(
                        0.1, 2.0, value=_GENERATION_DEFAULTS["temperature"], step=0.1, label="Temperature"
                    )
                with gr.Row():
                    top_p = gr.Slider(
                        0.1, 1.0, value=_GENERATION_DEFAULTS["top_p"], step=0.05, label="Top-p (nucleus sampling)"
                    )
                    top_k = gr.Slider(1, 100, value=_GENERATION_DEFAULTS["top_k"], step=1, label="Top-k")
                with gr.Row():
                    repetition_penalty = gr.Slider(
                        1.0, 2.0, value=_GENERATION_DEFAULTS["repetition_penalty"], step=0.1, label="Repetition Penalty"
                    )

            # Buttons
            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All")

        # Right column: latest reply and the running conversation.
        with gr.Column(scale=3):
            # Assistant Response
            with gr.Group():
                gr.Markdown("### Assistant Response")
                assistant_response = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False,
                    elem_classes=["message-box", "assistant-box"],
                )

            # Chat History
            # NOTE(review): history is kept as (user, assistant) pairs; newer
            # Gradio releases prefer the "messages" format - confirm against
            # the pinned Gradio version.
            with gr.Group():
                gr.Markdown("### Conversation History")
                chat_history = gr.Chatbot(
                    label="Chat History",
                    height=400,
                    elem_classes=["message-box"],
                )

    # The button click and Enter in the user box share one flow:
    # generate the reply, then record the exchange and clear the input.
    generation_inputs = [
        system_message,
        user_message,
        max_length,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    ]
    for trigger in (submit_btn.click, user_message.submit):
        trigger(
            fn=generate_text,
            inputs=generation_inputs,
            outputs=assistant_response,
        ).then(
            fn=_append_turn,
            inputs=[chat_history, user_message, assistant_response],
            outputs=[chat_history, user_message],
        )

    # Each output component is listed exactly once, in the order _reset_all()
    # returns its values.
    clear_btn.click(
        fn=_reset_all,
        outputs=[
            system_message,
            user_message,
            assistant_response,
            max_length,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            chat_history,
        ],
    )

# Run the app
if __name__ == "__main__":
    demo.launch()