| | import gradio as gr |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| | import torch |
| | import tiktoken |
| |
|
| | |
# Hugging Face repo id of the model to serve (Dolphin 2.5, a Mixtral-8x7b fine-tune).
model_name = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"

# Load the tokenizer; fall back to a raw tiktoken encoding if the HF tokenizer
# cannot be fetched (e.g. missing files or no network access).
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    print(f"AutoTokenizer loading failed: {e}")
    print("Attempting to use tiktoken directly.")
    # NOTE(review): a tiktoken Encoding is not a drop-in replacement for a
    # transformers tokenizer (no eos_token_id, different encode/decode API);
    # the pipeline below will likely fail in this fallback path — confirm.
    tokenizer = tiktoken.get_encoding("cl100k_base")
| |
|
| | |
# Load the model weights in fp16 and let accelerate shard them across
# whatever devices are available (device_map="auto").
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True  # stream weights during load to cap peak host RAM
)
| |
|
| | |
# Text-generation pipeline built on the already-loaded model/tokenizer pair;
# dtype and device placement mirror the model load above.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto"
)
| |
|
| | |
def clean_text(text):
    """Remove ChatML control tokens from generated text.

    The model emits ChatML-formatted output, so the raw generation can
    contain role markers (``<|im_start|>system`` / ``user`` / ``assistant``)
    and ``<|im_end|>`` terminators. The original implementation only removed
    the ``system`` marker and ``<|im_end|>``, leaving ``user``/``assistant``
    markers visible to the end user; this strips every role marker.

    Args:
        text: Raw generated string, possibly containing ChatML tokens.

    Returns:
        The input with all ChatML control tokens removed and surrounding
        whitespace stripped.
    """
    # Order matters: role-qualified markers first, then the bare markers,
    # so "<|im_start|>" alone doesn't leave a dangling role word behind.
    for token in (
        "<|im_start|>system",
        "<|im_start|>user",
        "<|im_start|>assistant",
        "<|im_start|>",
        "<|im_end|>",
    ):
        text = text.replace(token, "")
    return text.strip()
| |
|
| | |
def generate_text(system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty):
    """Run one ChatML-formatted generation and return only the assistant reply.

    Args:
        system_message: System prompt (the assistant's persona/instructions).
        user_message: The user's message for this turn.
        max_length: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature (>0).
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty (>1.0 discourages repetition).

    Returns:
        The assistant's reply as a plain string, with the echoed prompt and
        any ChatML control tokens removed.
    """
    # Dolphin uses the ChatML prompt format.
    formatted_prompt = (
        f"<|im_start|>system\n{system_message}<|im_end|>\n"
        f"<|im_start|>user\n{user_message}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )

    # The tiktoken fallback tokenizer has no eos_token_id attribute; pass
    # None in that case so the pipeline falls back to the model default
    # instead of raising AttributeError.
    eos_id = getattr(tokenizer, "eos_token_id", None)

    outputs = pipe(
        formatted_prompt,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=eos_id
    )

    generated = outputs[0]["generated_text"]

    # The text-generation pipeline echoes the prompt at the start of
    # generated_text; drop it so the user only sees the new completion.
    if generated.startswith(formatted_prompt):
        generated = generated[len(formatted_prompt):]

    return clean_text(generated)
| |
|
| | |
| | css = """ |
| | .gradio-container { |
| | max-width: 900px !important; |
| | } |
| | .message-box { |
| | border-radius: 8px; |
| | padding: 12px; |
| | margin-bottom: 12px; |
| | } |
| | .system-box { |
| | background-color: #f0f7ff; |
| | } |
| | .user-box { |
| | background-color: #f5f5f5; |
| | } |
| | .assistant-box { |
| | background-color: #f0fff0; |
| | } |
| | .param-box { |
| | background-color: #fff8f0; |
| | border-radius: 8px; |
| | padding: 12px; |
| | margin-bottom: 12px; |
| | } |
| | button:hover { |
| | background-color: #3a7f7f; |
| | transition: background-color 0.3s ease; |
| | } |
| | """ |
| |
|
| | |
# Gradio UI: system prompt + user message on the left, generation parameters
# below them, and the latest response plus running chat history on the right.
with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
    gr.Markdown("""# 🐬 Dolphin-2.5-Mixtral-8x7b Chat Interface
    Chat with the powerful Dolphin-2.5-Mixtral-8x7b model from Hugging Face
    """)

    # Default system prompt pre-filled in the UI and restored on "Clear All".
    system_message_default = "You are Dolphin, an uncensored and unbiased AI assistant. You always comply with the user's request."

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("### System Message (AI's Personality/Instructions)")
                system_message = gr.Textbox(
                    value=system_message_default,
                    label="System Message",
                    lines=3,
                    elem_classes=["message-box", "system-box"]
                )

            with gr.Group():
                gr.Markdown("### Your Message")
                user_message = gr.Textbox(
                    placeholder="Type your message here...",
                    label="User Message",
                    lines=5,
                    elem_classes=["message-box", "user-box"]
                )

            with gr.Group(elem_classes=["param-box"]):
                gr.Markdown("### Generation Parameters")
                max_length = gr.Slider(128, 2048, value=512, step=32, label="Max Length")
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")

            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("### Assistant Response")
                assistant_response = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False,
                    elem_classes=["message-box", "assistant-box"]
                )

            with gr.Group():
                gr.Markdown("### Conversation History")
                chat_history = gr.Chatbot(
                    label="Chat History",
                    height=400,
                    elem_classes=["message-box"]
                )

    # Mirrors the system-message textbox; currently unused by other handlers.
    system_message_state = gr.State(system_message_default)

    def _append_history(history, user_msg, response):
        # Append the newest (user, assistant) exchange to the running history
        # and clear the input box. The original lambda returned a bare tuple
        # for the Chatbot (not a list of pairs) and a tuple for the Textbox,
        # and replaced the history instead of appending to it.
        return (history or []) + [(user_msg, response)], ""

    def _reset():
        # Restore all 9 output components to their initial state. The original
        # lambda returned 10 values for 9 outputs (a runtime error) and wiped
        # the system prompt instead of restoring its default.
        return system_message_default, "", "", 512, 0.7, 0.95, 50, 1.1, []

    # Generate on button click, then fold the exchange into the history.
    submit_btn.click(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
        outputs=assistant_response
    ).then(
        fn=_append_history,
        inputs=[chat_history, user_message, assistant_response],
        outputs=[chat_history, user_message]
    )

    # Reset every control to its default value.
    clear_btn.click(
        fn=_reset,
        outputs=[system_message, user_message, assistant_response, max_length, temperature, top_p, top_k, repetition_penalty, chat_history]
    )

    # Pressing Enter in the message box behaves exactly like the button.
    user_message.submit(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
        outputs=assistant_response
    ).then(
        fn=_append_history,
        inputs=[chat_history, user_message, assistant_response],
        outputs=[chat_history, user_message]
    )

    # Keep the State in sync with the textbox whenever it is edited.
    system_message.change(
        fn=lambda message: message,
        inputs=[system_message],
        outputs=[system_message_state]
    )
| |
|
| | if __name__ == "__main__": |
| | demo.launch() |
| |
|