| | |
| |
|
import os
import warnings
from collections.abc import Iterator
from threading import Thread
from typing import List, Dict, Optional, Tuple
import time


# Silence ALL Python warnings globally (transformers/torch emit noisy
# deprecation warnings on CPU). NOTE(review): this also hides genuinely
# useful warnings — consider narrowing to specific categories.
warnings.filterwarnings("ignore")
| |
|
| | |
# Optional heavy dependencies are probed once at import time so the app can
# print actionable install instructions instead of crashing on ImportError.
try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TextIteratorStreamer
    )
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False


try:
    import gradio as gr
    GRADIO_AVAILABLE = True
except ImportError:
    GRADIO_AVAILABLE = False
| |
|
class CPULLMChat:
    """Load small CPU-friendly causal LMs and stream chat completions.

    Holds at most one model/tokenizer pair at a time; ``load_model`` swaps
    it in and ``generate_response`` streams partial replies suitable for a
    Gradio chatbot component.
    """

    def __init__(self):
        # Hub model id -> human-readable label shown in the UI dropdown.
        self.models = {
            "microsoft/DialoGPT-medium": "DialoGPT Medium (Recommended for chat)",
            "microsoft/DialoGPT-small": "DialoGPT Small (Faster)",
            "distilgpt2": "DistilGPT2 (Very fast)",
            "gpt2": "GPT2 (Standard)",
            "facebook/blenderbot-400M-distill": "BlenderBot (Conversational)"
        }

        # Currently loaded model/tokenizer (None until load_model succeeds).
        self.current_model = None
        self.current_tokenizer = None
        self.current_model_name = None
        self.model_loaded = False

        # Cap on prompt length (tokens) to keep CPU inference latency bounded.
        self.max_input_length = 2048
        self.device = "cpu"

    def load_model(self, model_name: str, progress=None) -> str:
        """Load *model_name* from the Hugging Face hub onto the CPU.

        Args:
            model_name: Hub id of the model to load.
            progress: Optional Gradio progress callback. BUGFIX: the original
                default was ``gr.Progress()``, which is evaluated when the
                class body executes and raises NameError whenever gradio is
                not installed — defeating the GRADIO_AVAILABLE guard. A
                ``None`` default with guarded calls is backward-compatible.

        Returns:
            A human-readable status string (success or error).
        """
        def report(fraction: float, desc: str) -> None:
            # Progress updates are best-effort; skip when no callback given.
            if progress is not None:
                progress(fraction, desc=desc)

        if not TRANSFORMERS_AVAILABLE:
            return "❌ Error: transformers library not installed. Run: pip install torch transformers"

        if model_name == self.current_model_name and self.model_loaded:
            return f"✅ Model {model_name} is already loaded!"

        try:
            report(0.1, "Loading tokenizer...")

            self.current_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left"
            )
            # GPT-2-family tokenizers ship without a pad token; reuse EOS.
            if self.current_tokenizer.pad_token is None:
                self.current_tokenizer.pad_token = self.current_tokenizer.eos_token

            report(0.5, "Loading model...")

            self.current_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,      # full precision: safest on CPU
                device_map={"": self.device},
                low_cpu_mem_usage=True
            )
            self.current_model.eval()           # inference only

            self.current_model_name = model_name
            self.model_loaded = True

            report(1.0, "Model loaded successfully!")
            return f"✅ Successfully loaded: {model_name}"

        except Exception as e:
            self.model_loaded = False
            return f"❌ Failed to load model {model_name}: {str(e)}"

    def generate_response(
        self,
        message: str,
        chat_history: List[List[str]],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
    ) -> Iterator[str]:
        """Stream a reply to *message*, yielding the growing partial text.

        Args:
            message: The new user message.
            chat_history: List of ``[user, bot]`` message pairs.
            max_new_tokens / temperature / top_p / top_k / repetition_penalty:
                Standard sampling parameters forwarded to ``generate``.

        Yields:
            The accumulated response text after each streamed token, or a
            single error message on failure.
        """
        if not self.model_loaded:
            yield "❌ Please load a model first!"
            return

        if not message.strip():
            yield "Please enter a message."
            return

        try:
            # Keep only the last 5 exchanges to bound the prompt size.
            recent_history = chat_history[-5:]

            if "DialoGPT" in self.current_model_name:
                # DialoGPT expects turns joined by EOS tokens rather than a
                # textual "User:/Assistant:" transcript.
                input_ids = self._build_dialogpt_input(message, recent_history)
            else:
                input_ids = self._build_transcript_input(message, recent_history)

            # Truncate from the left so the newest turns survive.
            if input_ids.shape[1] > self.max_input_length:
                input_ids = input_ids[:, -self.max_input_length:]

            streamer = TextIteratorStreamer(
                self.current_tokenizer,
                timeout=60.0,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = {
                'input_ids': input_ids,
                'streamer': streamer,
                'max_new_tokens': max_new_tokens,
                'temperature': temperature,
                'top_p': top_p,
                'top_k': top_k,
                'repetition_penalty': repetition_penalty,
                'do_sample': True,
                'pad_token_id': self.current_tokenizer.pad_token_id,
                'eos_token_id': self.current_tokenizer.eos_token_id,
                'no_repeat_ngram_size': 2,
            }

            # generate() blocks, so run it on a worker thread and consume
            # streamed tokens on this one.
            generation_thread = Thread(
                target=self.current_model.generate,
                kwargs=generation_kwargs
            )
            generation_thread.start()

            partial_response = ""
            for new_text in streamer:
                partial_response += new_text
                yield partial_response

        except Exception as e:
            yield f"❌ Generation error: {str(e)}"

    def _build_dialogpt_input(self, message: str, recent_history: List[List[str]]):
        """Encode history + message as EOS-separated turns (DialoGPT format)."""
        def append_turn(ids, text):
            # Each turn is terminated by the tokenizer's EOS token.
            encoded = self.current_tokenizer.encode(
                text + self.current_tokenizer.eos_token,
                return_tensors='pt'
            )
            if ids is None:
                return encoded
            return torch.cat([ids, encoded], dim=-1)

        chat_history_ids = None
        for user_msg, bot_msg in recent_history:
            if user_msg:
                chat_history_ids = append_turn(chat_history_ids, user_msg)
            if bot_msg:
                chat_history_ids = append_turn(chat_history_ids, bot_msg)

        return append_turn(chat_history_ids, message)

    def _build_transcript_input(self, message: str, recent_history: List[List[str]]):
        """Encode history + message as a plain "User:/Assistant:" transcript."""
        parts = [
            f"User: {user_msg}\nAssistant: {bot_msg}\n"
            for user_msg, bot_msg in recent_history
            if user_msg and bot_msg
        ]
        parts.append(f"User: {message}\nAssistant:")
        return self.current_tokenizer.encode("".join(parts), return_tensors='pt')
| |
|
def create_interface():
    """Build and return the Gradio Blocks UI, or None if deps are missing."""

    if not GRADIO_AVAILABLE:
        print("❌ Error: gradio library not installed. Run: pip install gradio")
        return None

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Error: transformers library not installed. Run: pip install torch transformers")
        return None

    # One chat system shared by all handlers (single-user app).
    chat_system = CPULLMChat()

    css = """
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    .chat-message {
        padding: 10px;
        margin: 5px 0;
        border-radius: 10px;
    }
    .user-message {
        background-color: #e3f2fd;
        margin-left: 20%;
    }
    .bot-message {
        background-color: #f1f8e9;
        margin-right: 20%;
    }
    """

    with gr.Blocks(css=css, title="CPU LLM Chat") as demo:
        gr.Markdown("# 🤖 CPU-Optimized LLM Chat")
        gr.Markdown("*A lightweight chat interface for running language models on CPU*")

        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(chat_system.models.keys()),
                    value="microsoft/DialoGPT-medium",
                    label="Select Model",
                    info="Choose a model to load. DialoGPT models work best for chat."
                )
                load_btn = gr.Button("🔄 Load Model", variant="primary")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="No model loaded",
                    interactive=False
                )

            with gr.Column(scale=1):
                gr.Markdown("### 💡 Model Info")
                gr.Markdown("""
                - **DialoGPT Medium**: Best quality, slower
                - **DialoGPT Small**: Good balance
                - **DistilGPT2**: Fastest option
                - **GPT2**: General purpose
                - **BlenderBot**: Conversational AI
                """)

        chatbot = gr.Chatbot(
            label="Chat History",
            height=400,
            show_label=True,
            container=True
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here... (Press Ctrl+Enter to send)",
                lines=3,
                max_lines=10,
                show_label=False
            )
            send_btn = gr.Button("📤 Send", variant="primary")

        with gr.Accordion("⚙️ Generation Parameters", open=False):
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=512,
                    value=256,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher values = more creative, lower = more focused"
                )

            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-p",
                    info="Nucleus sampling parameter"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k",
                    info="Top-k sampling parameter"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Penalty for repeating tokens"
                )

        with gr.Accordion("💬 Example Messages", open=False):
            examples = [
                "Hello! How are you today?",
                "Tell me a short story about a robot.",
                "What's the difference between AI and machine learning?",
                "Can you help me write a poem about nature?",
                "Explain quantum computing in simple terms.",
            ]

            example_buttons = []
            for example in examples:
                btn = gr.Button(example, variant="secondary")
                example_buttons.append(btn)

        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

        def respond(message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
            """Streaming chat handler.

            BUGFIX: this function is a generator (it uses ``yield`` below),
            so the original ``return history, ""`` on the unloaded-model
            branch was swallowed as a StopIteration value and Gradio never
            displayed the error — every exit path must ``yield``.
            """
            if not chat_system.model_loaded:
                history.append([message, "❌ Please load a model first!"])
                yield history, ""
                return

            # Placeholder bot message that streaming fills in below.
            history.append([message, ""])

            for partial_response in chat_system.generate_response(
                message, history, max_new_tokens, temperature, top_p, top_k, repetition_penalty
            ):
                history[-1][1] = partial_response
                yield history, ""

        def load_model_handler(model_name, progress=gr.Progress()):
            # gr.Progress() here is safe: this default is only evaluated when
            # create_interface runs, after GRADIO_AVAILABLE has been checked.
            return chat_system.load_model(model_name, progress)

        def set_example(example_text):
            # Copies the clicked example into the message box.
            return example_text

        def clear_chat():
            return [], ""

        load_btn.click(load_model_handler, inputs=[model_dropdown], outputs=[model_status])

        # Both Enter-submit and the Send button drive the same handler.
        msg.submit(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])
        send_btn.click(respond, inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k, repetition_penalty], outputs=[chatbot, msg])

        clear_btn.click(clear_chat, outputs=[chatbot, msg])

        # gr.State pins each button's example text at wiring time (avoids the
        # late-binding closure pitfall).
        for btn, example in zip(example_buttons, examples):
            btn.click(set_example, inputs=[gr.State(example)], outputs=[msg])

        gr.Markdown("""
        ---
        ### 📋 Instructions:
        1. **Select and load a model** using the dropdown and "Load Model" button
        2. **Wait for the model to load** (may take 1-2 minutes on first load)
        3. **Start chatting** once you see "✅ Successfully loaded" message
        4. **Adjust parameters** if needed for different response styles

        ### 💻 System Requirements:
        - CPU with at least 4GB RAM available
        - Python 3.8+ with torch and transformers installed

        ### ⚡ Performance Tips:
        - Use DialoGPT-small for fastest responses
        - Keep max tokens under 300 for better speed
        - Lower temperature (0.3-0.7) for more consistent responses
        """)

    return demo
| |
|
def main():
    """Entry point: verify optional dependencies, then launch the web UI."""

    print("===== CPU LLM Chat Application =====")
    print("Checking dependencies...")

    if not GRADIO_AVAILABLE:
        print("❌ Gradio not found. Install with: pip install gradio")
        return

    if not TRANSFORMERS_AVAILABLE:
        print("❌ Transformers not found. Install with: pip install torch transformers")
        return

    print("✅ All dependencies found!")
    print("Starting web interface...")

    try:
        demo = create_interface()
        if demo:
            # queue() bounds concurrent generation requests; the server binds
            # to all interfaces on the default Gradio port.
            # NOTE(review): the original also passed show_tips=True, which
            # was removed in Gradio 4.x and makes launch() raise TypeError
            # there — dropped for compatibility.
            demo.queue(max_size=10).launch(
                server_name="0.0.0.0",
                server_port=7860,
                share=False,
                show_error=True,
                inbrowser=False
            )
    except KeyboardInterrupt:
        print("\n👋 Application stopped by user")
    except Exception as e:
        print(f"❌ Error starting application: {e}")
| |
|
| | if __name__ == "__main__": |
| | main() |