import gradio as gr
# NOTE: Conversation and the "conversational" pipeline were removed in
# transformers v4.42; this app assumes an earlier transformers release.
from transformers import pipeline, Conversation
import torch
from huggingface_hub import login
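
# Assumed environment (not pinned anywhere in this Space): gradio, torch,
# transformers<4.42 (for the Conversation API), accelerate, and optionally
# flash-attn for Flash Attention 2.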

# --- Configuration ---
MODEL_CHOICES = [
    "mistralai/Mistral-7B-Instruct-v0.2",     # Good balance of quality and cost
    "meta-llama/Llama-2-70b-chat-hf",         # Higher quality; gated, requires an HF token and more resources
    "mistralai/Mixtral-8x7B-Instruct-v0.1",   # Potentially best quality, high resources
    "codellama/CodeLlama-70b-Instruct-hf",    # Best for code, high resources
]

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# --- Helper Functions ---
def load_model(model_name, hf_token=None):
    """Load a conversational pipeline for the given model, handling authentication."""
    try:
        if hf_token:
            login(token=hf_token)
        # Use a pipeline for easier interaction.
        pipe = pipeline(
            "conversational",
            model=model_name,
            device=DEVICE,               # Move to GPU if available
            torch_dtype=torch.bfloat16,  # bfloat16 for faster inference (if supported)
            trust_remote_code=True,      # Needed for models that ship custom code
            # Flash Attention 2 is a from_pretrained option, not a pipeline kwarg,
            # so it must be routed through model_kwargs (requires flash-attn installed).
            model_kwargs={"attn_implementation": "flash_attention_2"},
        )
        return pipe, "Model loaded successfully!"
    except Exception as e:
        return None, f"Error loading model: {e}"
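
# Optional warm-up (an assumption, not part of the original app): preload the
# default model at startup so the first chat turn doesn't pay the load cost.
# _pipe, _status = load_model(MODEL_CHOICES[0])
# print(_status)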

def generate_response(prompt, chat_history, model_name, hf_token=None):
    """Generate a response using the conversational pipeline."""
    # Cache loaded pipelines on the function object for faster model switching.
    if not hasattr(generate_response, "loaded_models"):
        generate_response.loaded_models = {}
    if model_name not in generate_response.loaded_models:
        pipe, load_status = load_model(model_name, hf_token)
        if pipe is None:
            # Surface the load error in the message box; leave the history unchanged.
            return load_status, chat_history
        generate_response.loaded_models[model_name] = pipe
        print(f"Model {model_name} loaded.")  # Debugging message
    else:
        print(f"Using cached model {model_name}.")  # Debugging message
    pipe = generate_response.loaded_models[model_name]
    try:
        # Convert the Gradio chat history (a list of (user, bot) tuples) to a
        # transformers Conversation.
        conversation = Conversation()
        for user_message, bot_message in chat_history:
            conversation.add_message({"role": "user", "content": user_message})
            if bot_message:  # Handle the case where the bot hasn't responded yet
                conversation.add_message({"role": "assistant", "content": bot_message})
        conversation.add_message({"role": "user", "content": prompt})
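
        # For example, a history of [("Hi", "Hello!")] plus the prompt
        # "How are you?" yields a Conversation whose messages are:
        #   [{"role": "user", "content": "Hi"},
        #    {"role": "assistant", "content": "Hello!"},
        #    {"role": "user", "content": "How are you?"}]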
        # Generation parameters (tune these to taste).
        generation_kwargs = {
            "max_new_tokens": 512,
            "do_sample": True,
            "top_p": 0.95,
            "temperature": 0.7,
            "repetition_penalty": 1.1,
        }
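
        # do_sample with top_p/temperature trades determinism for variety;
        # repetition_penalty > 1.0 discourages the model from looping on itself.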
        # Generate the response; the pipeline returns the Conversation with the
        # assistant's reply appended as the final message.
        response = pipe(conversation, **generation_kwargs)
        bot_response = response.messages[-1]["content"]

        # Update the chat history and clear the input box.
        chat_history.append((prompt, bot_response))
        return "", chat_history
    except Exception as e:
        return f"Error during generation: {e}", chat_history

# --- Gradio Interface ---
with gr.Blocks(title="Chat with a Powerful AI") as iface:
    gr.Markdown(
        """
        # Chat with Different AI Models
        This Space demonstrates a chatbot that allows you to select from different AI models.
        Choose a model from the dropdown and start chatting!
        """
    )
    model_selection = gr.Dropdown(
        choices=MODEL_CHOICES,
        value=MODEL_CHOICES[0],  # Default model
        label="Select Model",
        info="Choose the AI model you want to chat with.",
    )
    hf_token_input = gr.Textbox(
        label="Hugging Face Token (Optional, for gated models)",
        type="password",
        placeholder="Enter your Hugging Face token (if required)",
    )
    chatbot = gr.Chatbot(label="Chat History", height=500)  # Set a reasonable height
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.ClearButton([msg, chatbot])

    msg.submit(
        generate_response,
        [msg, chatbot, model_selection, hf_token_input],
        [msg, chatbot],
    )

iface.launch()
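
# Optional tweak (an assumption, not from the original app): on shared hardware,
# Gradio's request queue helps under concurrent load, e.g. replace the plain
# launch() above with:
# iface.queue(max_size=16).launch()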