# Hugging Face Space app (scraped page header said "Spaces: Sleeping" — runtime status, not code)
import gradio as gr
from transformers import pipeline, Conversation
import torch
from huggingface_hub import login
# --- Configuration ---

# Candidate chat models, ordered roughly by resource requirements.
MODEL_CHOICES = [
    "mistralai/Mistral-7B-Instruct-v0.2",    # Good balance
    "meta-llama/Llama-2-70b-chat-hf",        # Higher quality, requires HF token, more resources
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # Potentially best quality, high resources
    "codellama/CodeLlama-70b-Instruct-hf",   # Best for code, high resources
]

# Prefer the first CUDA device when one is visible; otherwise run on CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# --- Helper Functions ---
def load_model(model_name, hf_token=None):
    """Load a conversational pipeline for ``model_name``.

    Args:
        model_name: Hugging Face model id (one of ``MODEL_CHOICES``).
        hf_token: Optional access token for gated models (e.g. Llama 2).

    Returns:
        A ``(pipe, status_message)`` tuple. ``pipe`` is ``None`` when
        loading failed, and the message then describes the error.
    """
    try:
        if hf_token:
            login(token=hf_token)
        # Use a pipeline for easier interaction.
        # NOTE(review): the "conversational" task is deprecated/removed in
        # recent transformers releases — confirm the pinned version supports it.
        pipe = pipeline(
            "conversational",
            model=model_name,
            device=DEVICE,               # Move to GPU if available
            torch_dtype=torch.bfloat16,  # Use bfloat16 for faster inference (if supported)
            trust_remote_code=True,      # Important for custom models
            # BUG FIX: `use_flash_attention_2` is not a pipeline() argument.
            # Model-construction kwargs must be routed through `model_kwargs`,
            # and the current spelling is `attn_implementation`.
            model_kwargs={"attn_implementation": "flash_attention_2"},
        )
        return pipe, "Model loaded successfully!"
    except Exception as e:
        # Boundary handler: surface the failure to the UI instead of crashing.
        return None, f"Error loading model: {e}"
def generate_response(prompt, chat_history, model_name, hf_token=None):
    """Generate a reply to ``prompt`` and append the exchange to ``chat_history``.

    Returns a ``(textbox_value, chat_history)`` pair: the textbox value is
    ``""`` on success (clearing the input) or an error message on failure.
    """
    # Per-function model cache so switching back to a model is instant.
    cache = getattr(generate_response, "loaded_models", None)
    if cache is None:
        cache = {}
        generate_response.loaded_models = cache

    pipe = cache.get(model_name)
    if pipe is None:
        pipe, load_status = load_model(model_name, hf_token)
        if pipe is None:
            # Loading failed: report the status through the textbox output.
            return load_status, chat_history
        cache[model_name] = pipe
        print(f"Model {model_name} loaded.")  # Debugging message
    else:
        print(f"Using cached model {model_name}.")  # Debugging message

    try:
        # Rebuild a transformers Conversation from Gradio's (user, bot) pairs.
        conversation = Conversation()
        for user_turn, bot_turn in chat_history:
            conversation.add_message({"role": "user", "content": user_turn})
            if bot_turn:  # Handle case where bot hasn't responded yet
                conversation.add_message({"role": "assistant", "content": bot_turn})
        conversation.add_message({"role": "user", "content": prompt})

        # Generation parameters (adjust these!)
        response = pipe(
            conversation,
            max_new_tokens=512,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            repetition_penalty=1.1,
        )

        # The newest message in the returned Conversation is the bot's reply.
        bot_reply = response.messages[-1]["content"]
        chat_history.append((prompt, bot_reply))
        return "", chat_history
    except Exception as e:
        return f"Error during generation: {e}", chat_history
# --- Gradio Interface ---
with gr.Blocks(title="Chat with a Powerful AI") as demo:
    gr.Markdown(
        """
        # Chat with Different AI Models
        This Space demonstrates a chatbot that allows you to select from different AI models.
        Choose a model from the dropdown and start chatting!
        """
    )

    # Model picker; the first entry in MODEL_CHOICES is the default.
    model_dropdown = gr.Dropdown(
        choices=MODEL_CHOICES,
        value=MODEL_CHOICES[0],
        label="Select Model",
        info="Choose the AI model you want to chat with.",
    )

    # Token is only needed for gated checkpoints (e.g. Llama 2).
    token_box = gr.Textbox(
        label="Hugging Face Token (Optional, for gated models)",
        type="password",
        placeholder="Enter your Hugging Face token (if required)",
    )

    history_display = gr.Chatbot(label="Chat History", height=500)
    user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear_button = gr.ClearButton([user_input, history_display])

    # Enter in the textbox drives generation; outputs clear the textbox
    # (or show an error) and refresh the chat history.
    user_input.submit(
        generate_response,
        [user_input, history_display, model_dropdown, token_box],
        [user_input, history_display],
    )

demo.launch()