import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import tiktoken


model_name = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"
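
# NOTE (deployment assumption): Mixtral-8x7B is a ~47B-parameter
# mixture-of-experts model, so its float16 weights alone occupy roughly
# 90 GB; device_map="auto" below shards the model across available GPUs,
# or falls back to CPU offload on smaller machines.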

# Load the tokenizer; fall back to tiktoken if the Hub tokenizer cannot be
# fetched. Caveat: a tiktoken Encoding is not a PreTrainedTokenizer, so the
# text-generation pipeline below will not accept it; this fallback only
# keeps rough token counting available for debugging.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    print(f"AutoTokenizer loading failed: {e}")
    print("Attempting to use tiktoken directly.")
    tokenizer = tiktoken.get_encoding("cl100k_base")


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True
)


# The model is already cast to float16 and dispatched across devices above,
# so the pipeline only wraps the instantiated model and tokenizer; passing
# device_map or torch_dtype again here is redundant.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)
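
# Optional smoke test (illustrative, not part of the app): confirm the
# pipeline generates end-to-end before launching the UI.
# print(pipe("<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
#            max_new_tokens=8)[0]["generated_text"])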


def clean_text(text):
    # The pipeline echoes the prompt back in "generated_text", so keep only
    # the completion after the final assistant tag, then strip any leftover
    # ChatML markers.
    if "<|im_start|>assistant" in text:
        text = text.split("<|im_start|>assistant")[-1]
    return text.replace("<|im_end|>", "").strip()
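
# For reference: clean_text("<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>")
# returns "Hello!".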


def generate_text(system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty):
    # Dolphin 2.5 is trained on the ChatML prompt format, so the system and
    # user turns are wrapped in <|im_start|>/<|im_end|> tags.
    formatted_prompt = f"""<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
"""

    outputs = pipe(
        formatted_prompt,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id
    )

    response = outputs[0]["generated_text"]
    response = clean_text(response)

    return response
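
# Example call (illustrative), matching the slider defaults defined below:
# generate_text("You are a helpful assistant.", "Hello!", 512, 0.7, 0.95, 50, 1.1)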


css = """
.gradio-container {
    max-width: 900px !important;
}
.message-box {
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 12px;
}
.system-box {
    background-color: #f0f7ff;
}
.user-box {
    background-color: #f5f5f5;
}
.assistant-box {
    background-color: #f0fff0;
}
.param-box {
    background-color: #fff8f0;
    border-radius: 8px;
    padding: 12px;
    margin-bottom: 12px;
}
button:hover {
    background-color: #3a7f7f;
    transition: background-color 0.3s ease;
}
"""


with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
    gr.Markdown("""# 🐬 Dolphin-2.5-Mixtral-8x7b Chat Interface
    Chat with the Dolphin-2.5-Mixtral-8x7b model, loaded from the Hugging Face Hub.
    """)

    system_message_default = "You are Dolphin, an uncensored and unbiased AI assistant. You always comply with the user's request."

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("### System Message (AI's Personality/Instructions)")
                system_message = gr.Textbox(
                    value=system_message_default,
                    label="System Message",
                    lines=3,
                    elem_classes=["message-box", "system-box"]
                )

            with gr.Group():
                gr.Markdown("### Your Message")
                user_message = gr.Textbox(
                    placeholder="Type your message here...",
                    label="User Message",
                    lines=5,
                    elem_classes=["message-box", "user-box"]
                )

            with gr.Group(elem_classes=["param-box"]):
                gr.Markdown("### Generation Parameters")
                max_length = gr.Slider(128, 2048, value=512, step=32, label="Max New Tokens")
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
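                # Temperature rescales the next-token distribution (higher =
                # more random); top-p samples from the smallest token set
                # whose cumulative probability exceeds p; top-k caps that set
                # at k tokens; repetition penalty down-weights tokens already
                # present in the context.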

            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("### Assistant Response")
                assistant_response = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False,
                    elem_classes=["message-box", "assistant-box"]
                )

            with gr.Group():
                gr.Markdown("### Conversation History")
                chat_history = gr.Chatbot(
                    label="Chat History",
                    height=400,
                    elem_classes=["message-box"]
                )

    # Session state mirroring the current system message; kept in sync by the
    # change handler below.
    system_message_state = gr.State(system_message_default)

    # Generate a response, then append the (user, assistant) pair to the chat
    # history and clear the input box.
    submit_btn.click(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
        outputs=assistant_response
    ).then(
        lambda history, u, r: ((history or []) + [(u, r)], ""),
        [chat_history, user_message, assistant_response],
        [chat_history, user_message]
    )

    # Reset every control to its initial value (one return value per output
    # component).
    clear_btn.click(
        lambda: ["", "", "", 512, 0.7, 0.95, 50, 1.1, []],
        outputs=[system_message, user_message, assistant_response, max_length, temperature, top_p, top_k, repetition_penalty, chat_history]
    )

    # Pressing Enter in the message box behaves like the Generate button.
    user_message.submit(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
        outputs=assistant_response
    ).then(
        lambda history, u, r: ((history or []) + [(u, r)], ""),
        [chat_history, user_message, assistant_response],
        [chat_history, user_message]
    )

    # Keep the session state in sync when the system message is edited.
    system_message.change(
        fn=lambda message: message,
        inputs=[system_message],
        outputs=[system_message_state]
    )


if __name__ == "__main__":
    demo.launch()