Spaces:
Sleeping
Sleeping
import os
import time
from typing import Dict, Iterator, List

import gradio as gr
class ChatbotHandler:
    """Wraps an 8-bit-quantized OPT-6.7B text-generation pipeline for chat.

    The model is loaded eagerly at construction time. If the transformers
    stack is unavailable or loading fails, the handler stays in a
    "not loaded" state (``model_loaded`` is False, ``chat_pipeline`` is
    None) and ``get_response`` reports that to the caller instead of
    raising.
    """

    def __init__(self):
        self.model_name = "facebook/opt-6.7b"  # Smaller, faster 6.7B model instead of 13B
        self.tokenizer = None
        self.model = None
        self.chat_pipeline = None
        self.max_length = 512  # Reduced context length for speed
        self.temperature = 0.7
        self.model_loaded = False
        self.system_prompt = """You are a helpful, friendly, and knowledgeable AI assistant.
You provide clear, accurate, and thoughtful responses. You are engaging and try to be
helpful while being honest about your limitations. Always maintain a positive and
supportive tone in your conversations."""
        # Load the model up front so the UI can show readiness at startup.
        self.initialize_model()

    def initialize_model(self) -> bool:
        """Load tokenizer, 8-bit quantized model, and generation pipeline.

        Returns:
            bool: True on success; False when the transformers/torch stack
            is missing or model loading fails. Errors are printed, never
            raised, so construction cannot crash the app.
        """
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
            import torch
        except ImportError:
            print("Transformers library not available. Please install the required dependencies.")
            return False
        try:
            print("Loading OPT-6.7B model with 8-bit quantization... This should be faster.")
            # 8-bit weights cut memory use; fp32 CPU offload lets layers that
            # don't fit on the GPU spill to the CPU instead of failing.
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=quantization_config,
                device_map="auto",  # Automatically distribute across available GPUs
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True
            )
            # OPT ships without a dedicated pad token; reuse EOS so padding
            # and truncation work.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # NOTE: `use_fast=True` was dropped from this call — it is a
            # tokenizer-loading hint and is dead when a tokenizer object is
            # passed explicitly.
            self.chat_pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto",
                max_length=self.max_length,
                temperature=self.temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                truncation=True
            )
            print("Model loaded successfully!")
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            return False

    def get_response(self, message: str, history: List[Dict]) -> Iterator[str]:
        """Generate a reply and stream it back a few words at a time.

        Args:
            message: The user's latest message.
            history: Prior turns as ``{"role", "content"}`` dicts.

        Yields:
            str: The accumulated response text so far (growing prefix).
        """
        if not self.chat_pipeline:
            # BUG FIX: this function is a generator, so the original
            # `return "<message>"` was never delivered to callers iterating
            # it — the message must be yielded.
            yield "Model not loaded. Please try again later."
            return
        try:
            # Build a flat prompt: system prompt, recent turns, new message.
            conversation = self.system_prompt + "\n"
            # Keep only the last 2 messages so the prompt (and latency) stays small.
            for msg in history[-2:]:
                if msg["role"] == "user":
                    conversation += f"User: {msg['content']}\n"
                elif msg["role"] == "assistant":
                    conversation += f"Assistant: {msg['content']}\n"
            conversation += f"User: {message}\nAssistant:"
            start_time = time.time()
            outputs = self.chat_pipeline(
                conversation,
                max_new_tokens=50,  # Shorter responses for speed
                num_return_sequences=1,
                return_full_text=False,
                do_sample=True,
                temperature=self.temperature,
                top_p=0.9,  # Nucleus sampling for better quality
                repetition_penalty=1.1  # Reduce repetition
            )
            end_time = time.time()
            print(f"Response generated in {end_time - start_time:.2f} seconds")
            response = outputs[0]['generated_text'].strip()
            # Strip a leaked role prefix; if the model answered as the user,
            # discard the output entirely.
            if response.startswith("Assistant:"):
                response = response[len("Assistant:"):].strip()
            elif response.startswith("User:"):
                response = "I apologize, but I seem to have gotten confused. How can I help you?"
            # Hard cap on length for speed.
            if len(response) > 200:
                response = response[:200] + "..."
            # Stream in 3-word chunks for a smooth typing effect.
            words = response.split()
            current_response = ""
            chunk_size = 3
            for i in range(0, len(words), chunk_size):
                chunk = words[i:i + chunk_size]
                current_response += " ".join(chunk) + " "
                yield current_response.strip()
                time.sleep(0.01)  # Very short delay for smooth streaming
        except Exception as e:
            yield f"I apologize, but I encountered an error. Please try again. Error: {str(e)}"
# Module-level handler singleton shared by every UI callback below.
# Constructing it triggers the (slow) model load at import time.
chat_handler = ChatbotHandler()
def respond_stream(message: str, history: List[Dict]):
    """Stream a model response into the Gradio chat, managing history copies.

    Args:
        message: Text from the input box.
        history: Current messages-format chat history.

    Yields:
        tuple[str, list]: ("", updated_history) — empty string clears the
        input box, the history drives the Chatbot component.
    """
    if not message.strip():
        # BUG FIX: this is a generator, so the original `return "", history`
        # yielded nothing and Gradio received no update at all.
        yield "", history
        return
    # Work on a copy so the component's state is never mutated in place.
    current_history = history.copy()
    # Add the user message first so it never disappears from the chat.
    current_history.append({"role": "user", "content": message})
    if not chat_handler.chat_pipeline:
        # BUG FIX: same generator issue — the loading notice must be yielded.
        current_history.append({"role": "assistant", "content": "The chatbot model is still loading. Please wait a moment and try again."})
        yield "", current_history
        return
    full_response = ""
    assistant_added = False
    try:
        # Exclude the just-appended user message from the context; the model
        # receives it separately as the current message.
        for chunk in chat_handler.get_response(message, current_history[:-1]):
            full_response = chunk
            # First chunk appends the assistant turn; later chunks update it.
            if not assistant_added:
                current_history.append({"role": "assistant", "content": full_response})
                assistant_added = True
            else:
                current_history[-1]["content"] = full_response
            yield "", current_history
    except Exception as e:
        # If streaming fails mid-way, surface a fallback reply instead of crashing.
        error_msg = "I apologize, but I encountered an error. Please try again."
        if not assistant_added:
            current_history.append({"role": "assistant", "content": error_msg})
        else:
            current_history[-1]["content"] = error_msg
        yield "", current_history
def clear_history():
    """Return a fresh, empty message list to reset the Chatbot component."""
    fresh_history = []
    return fresh_history
def update_model_settings(temp, max_len):
    """Push slider values into the shared handler's generation settings.

    Args:
        temp: Sampling temperature from the Temperature slider.
        max_len: Context length from the Max Length slider.

    Returns:
        str: Human-readable confirmation of the applied settings.
    """
    chat_handler.temperature = temp
    chat_handler.max_length = max_len
    # BUG FIX: the confirmation previously hard-coded "max_length=21,752"
    # regardless of the actual slider value.
    return f"Settings updated: temp={temp}, max_length={max_len}"
# Create the interface. Statement order inside the Blocks context defines
# the vertical layout, so the construction order below is significant.
with gr.Blocks(theme=gr.themes.Soft(), title="Fast AI Chatbot with OPT-6.7B") as demo:
    # Header banner
    gr.HTML("""
    <div style='text-align: center; padding: 20px;'>
        <h1>⚡ Fast AI Chatbot</h1>
        <p style='color: #666;'>Powered by OPT-6.7B with 8-bit quantization • Built with <a href='https://huggingface.co/spaces/akhaliq/anycoder' target='_blank' style='color: #007bff; text-decoration: none;'>anycoder</a></p>
    </div>
    """)
    # Status indicator — evaluated once at UI build time, so it reflects the
    # load state when the page was constructed, not live status.
    if chat_handler.model_loaded:
        status_msg = "✅ Chatbot is ready! Responses should take 1-3 seconds."
        status_color = "#28a745"
    else:
        status_msg = "⏳ Loading OPT-6.7B model with quantization... Should be faster than before."
        status_color = "#ffc107"
    gr.HTML(f"""
    <div style='text-align: center; padding: 10px; background-color: {status_color}15; border: 1px solid {status_color}30; border-radius: 5px; margin: 10px 0;'>
        <p style='color: {status_color}; margin: 0;'>{status_msg}</p>
    </div>
    """)
    # Generation settings, collapsed by default.
    with gr.Accordion("Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Higher values make responses more creative"
            )
            max_length = gr.Slider(
                minimum=256,
                maximum=1024,
                value=512,
                step=64,
                label="Max Length",
                info="Maximum context length (lower = faster)"
            )
    # Conversation display; messages type expects {"role", "content"} dicts,
    # matching what respond_stream yields.
    chatbot = gr.Chatbot(
        type="messages",
        label="Conversation",
        height=500,
        show_copy_button=True,
        bubble_full_width=False,
        avatar_images=(None, "https://huggingface.co/datasets/huggingface/avatars/resolve/main/bot-avatar.png")
    )
    # Input row: textbox plus explicit Send button.
    with gr.Row():
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here and press Enter...",
            container=False,
            scale=4
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)
    # Control buttons
    with gr.Row():
        clear_btn = gr.Button("Clear Chat", variant="secondary")
        refresh_btn = gr.Button("Refresh Settings", variant="secondary")
    # Clickable example prompts that populate the textbox.
    with gr.Accordion("Example Questions", open=False):
        gr.Examples(
            examples=[
                "What's the difference between AI and machine learning?",
                "Can you explain quantum computing in simple terms?",
                "Help me write a professional email.",
                "What are some good books to learn programming?",
                "Can you help me brainstorm ideas for a project?",
                "Explain the concept of blockchain technology."
            ],
            inputs=msg,
            label="Click an example to start chatting"
        )
    # Footer
    gr.HTML("""
    <div style='text-align: center; padding: 10px; color: #888; font-size: 0.9em;'>
        <p>This chatbot uses Meta's OPT-6.7B model with 8-bit quantization for fast responses (1-3 seconds). It's completely free to use!</p>
        <p><strong>Speed optimizations:</strong> Smaller model, quantization, shorter responses, optimized parameters.</p>
    </div>
    """)
    # Event handlers
    # Both Enter in the textbox and the Send button stream via respond_stream;
    # outputs clear the textbox and update the chat on each yielded chunk.
    msg.submit(
        respond_stream,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot]
    )
    submit_btn.click(
        respond_stream,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot]
    )
    # Clear chat
    clear_btn.click(clear_history, outputs=chatbot)
    # Update model settings live as either slider moves.
    # NOTE(review): these callbacks return a status string but declare no
    # outputs, so the returned text is discarded — confirm intended.
    temperature.change(
        update_model_settings,
        inputs=[temperature, max_length],
        outputs=[]
    )
    max_length.change(
        update_model_settings,
        inputs=[temperature, max_length],
        outputs=[]
    )
    # Refresh settings (useful for debugging); return value is likewise
    # discarded because outputs=[].
    refresh_btn.click(
        lambda: f"Settings: temp={chat_handler.temperature}, max_length={chat_handler.max_length}",
        outputs=[]
    )

# Launch the app; share=True requests a public tunnel link.
if __name__ == "__main__":
    demo.launch(share=True)