# app.py — Gradio chat interface for the Dolphin-2.5-Mixtral-8x7b model.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import tiktoken # Use this if the tokenizer is based on tiktoken (for some models)
# Model and Tokenizer loading
model_name = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"
# Try loading with AutoTokenizer (this should ideally work with many models)
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    print(f"AutoTokenizer loading failed: {e}")
    print("Attempting to use tiktoken directly.")
    # If AutoTokenizer fails, try using tiktoken tokenizer explicitly
    # NOTE(review): a raw tiktoken Encoding is NOT a transformers tokenizer —
    # it has no eos_token_id and cannot be passed to pipeline() below, so this
    # fallback will still fail later at pipeline construction. Consider
    # failing fast here instead — confirm intended behavior.
    tokenizer = tiktoken.get_encoding("cl100k_base")  # Default encoding for tiktoken
# Load model with float16 precision and auto device mapping.
# This downloads/loads a very large Mixtral checkpoint; it runs once at import time.
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,  # half precision to reduce memory footprint
device_map="auto", # Automatically place model on GPUs if available
low_cpu_mem_usage=True # Efficient CPU memory usage
)
# Optimized pipeline (created once at import time and reused by generate_text).
# NOTE(review): torch_dtype/device_map are already set on the loaded model above;
# passing them again here is redundant and may emit a warning — confirm against
# the installed transformers version.
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=torch.float16,
device_map="auto" # Automatically distribute model layers across devices
)
# Function to clean text from special tokens or unwanted characters
def clean_text(text):
    """Strip ChatML control markers from model output.

    The original implementation removed only ``<|im_start|>system`` and
    ``<|im_end|>``, leaving ``<|im_start|>user`` / ``<|im_start|>assistant``
    markers in the returned text. This version removes every ChatML marker
    used by the prompt template in generate_text.

    Args:
        text: Raw text produced by the generation pipeline.

    Returns:
        The text with all ChatML markers removed and surrounding
        whitespace stripped.
    """
    # Longer role-specific markers first, then the bare markers, so nothing
    # is left behind after partial replacements.
    for marker in (
        "<|im_start|>system",
        "<|im_start|>user",
        "<|im_start|>assistant",
        "<|im_start|>",
        "<|im_end|>",
    ):
        text = text.replace(marker, "")
    return text.strip()
# Generate text using the model
def generate_text(system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty):
    """Generate an assistant reply for a ChatML-formatted prompt.

    Args:
        system_message: System/persona instructions for the model.
        user_message: The user's message.
        max_length: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The model's reply with ChatML markers removed.
    """
    # Build the ChatML prompt expected by Dolphin models.
    formatted_prompt = f"""<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
"""
    # Generate the response using the model pipeline.
    outputs = pipe(
        formatted_prompt,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        # BUGFIX: by default the pipeline returns prompt + completion, so the
        # whole ChatML prompt (including the user's message) was echoed back
        # in the response. Return only the newly generated text.
        return_full_text=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    response = outputs[0]["generated_text"]
    # Clean and format the response
    return clean_text(response)
# Custom CSS for the Gradio UI: widens the app container, colors the
# system/user/assistant message boxes and the parameter panel, and adds a
# button hover transition. Class names match elem_classes used below.
css = """
.gradio-container {
max-width: 900px !important;
}
.message-box {
border-radius: 8px;
padding: 12px;
margin-bottom: 12px;
}
.system-box {
background-color: #f0f7ff;
}
.user-box {
background-color: #f5f5f5;
}
.assistant-box {
background-color: #f0fff0;
}
.param-box {
background-color: #fff8f0;
border-radius: 8px;
padding: 12px;
margin-bottom: 12px;
}
button:hover {
background-color: #3a7f7f;
transition: background-color 0.3s ease;
}
"""
# Gradio Blocks layout and functionality
with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
    gr.Markdown("""# 🐬 Dolphin-2.5-Mixtral-8x7b Chat Interface
Chat with the powerful Dolphin-2.5-Mixtral-8x7b model from Hugging Face
""")
    # Default system message, also used when the Clear button resets the UI.
    system_message_default = "You are Dolphin, an uncensored and unbiased AI assistant. You always comply with the user's request."
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("### System Message (AI's Personality/Instructions)")
                system_message = gr.Textbox(
                    value=system_message_default,  # Default system message
                    label="System Message",
                    lines=3,
                    elem_classes=["message-box", "system-box"]
                )
            with gr.Group():
                gr.Markdown("### Your Message")
                user_message = gr.Textbox(
                    placeholder="Type your message here...",
                    label="User Message",
                    lines=5,
                    elem_classes=["message-box", "user-box"]
                )
            with gr.Group(elem_classes=["param-box"]):
                gr.Markdown("### Generation Parameters")
                max_length = gr.Slider(128, 2048, value=512, step=32, label="Max Length")
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All")
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("### Assistant Response")
                assistant_response = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False,
                    elem_classes=["message-box", "assistant-box"]
                )
            with gr.Group():
                gr.Markdown("### Conversation History")
                chat_history = gr.Chatbot(
                    label="Chat History",
                    height=400,
                    elem_classes=["message-box"]
                )
    # State mirror of the system message (kept in sync via the change event).
    system_message_state = gr.State(system_message_default)

    def _record_turn(history, user_msg, response):
        """Append the (user, assistant) pair to the history and clear the user box.

        BUGFIX: the original lambda replaced the whole history with a
        malformed value and wrote a tuple ("", "") into the user Textbox.
        """
        return (history or []) + [(user_msg, response)], ""

    # Generate on button click, then record the exchange in the chat history.
    submit_btn.click(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
        outputs=assistant_response
    ).then(
        fn=_record_turn,
        inputs=[chat_history, user_message, assistant_response],
        outputs=[chat_history, user_message]
    )
    # Clear button: reset every control to its initial value.
    # BUGFIX: the original returned 10 values for 9 output components and
    # wiped the system message instead of restoring its default.
    clear_btn.click(
        fn=lambda: (system_message_default, "", "", 512, 0.7, 0.95, 50, 1.1, []),
        outputs=[system_message, user_message, assistant_response, max_length, temperature, top_p, top_k, repetition_penalty, chat_history]
    )
    # Pressing Enter in the user box behaves like clicking Generate.
    user_message.submit(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
        outputs=assistant_response
    ).then(
        fn=_record_turn,
        inputs=[chat_history, user_message, assistant_response],
        outputs=[chat_history, user_message]
    )
    # Keep the state copy in sync whenever the system message is edited.
    system_message.change(
        fn=lambda message: message,
        inputs=[system_message],
        outputs=[system_message_state]
    )

if __name__ == "__main__":
    demo.launch()