import gradio as gr
import torch
from transformers import pipeline
class LLMModule:
    def __init__(self):
        # Small instruction-tuned models that fit on modest hardware
        self.model_options = {
            "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "Phi-2": "microsoft/phi-2",
            "Qwen 0.5B": "Qwen/Qwen2.5-0.5B-Instruct"
        }
        self.current_model = None
        self.pipe = None
        self.chat_history = []
    def load_model(self, model_name):
        """Load the selected LLM into a text-generation pipeline."""
        try:
            model_id = self.model_options[model_name]
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Use half precision on GPU to cut memory use; fall back to fp32 on CPU
            self.pipe = pipeline(
                "text-generation",
                model=model_id,
                device=device,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            )
            self.current_model = model_name
            self.chat_history = []
            return f"✅ Loaded {model_name} on {device}"
        except Exception as e:
            return f"❌ Error loading model: {str(e)}"
    def _display_pairs(self):
        """Pair consecutive (user, assistant) turns for the Chatbot component."""
        return [
            (self.chat_history[i]["content"], self.chat_history[i + 1]["content"])
            for i in range(0, len(self.chat_history) - 1, 2)
        ]

    def generate_response(self, message, max_tokens, temperature):
        """Generate an LLM response and update the chat display."""
        if self.pipe is None:
            return "⚠️ Please load a model first", []
        if not message.strip():
            return "⚠️ Please enter a message", self._display_pairs()
        try:
            # Add the user message to the history
            self.chat_history.append({"role": "user", "content": message})
            # Generate a response. Note: only the latest message is sent to
            # the model; earlier turns are kept for display, not as context
            response = self.pipe(
                message,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                do_sample=True,
                top_p=0.9
            )
            assistant_message = response[0]["generated_text"]
            # Strip the echoed prompt if the model repeats the input
            if assistant_message.startswith(message):
                assistant_message = assistant_message[len(message):].strip()
            # Add the assistant response to the history
            self.chat_history.append({"role": "assistant", "content": assistant_message})
            return "", self._display_pairs()
        except Exception as e:
            return f"❌ Error generating response: {str(e)}", self._display_pairs()
    def clear_history(self):
        """Clear chat history."""
        self.chat_history = []
        return [], ""
    def create_interface(self):
        """Create the Gradio interface for LLM testing."""
        with gr.Column() as interface:
            gr.Markdown("## 🤖 LLM Testing")
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(self.model_options.keys()),
                    value="Qwen 0.5B",
                    label="Select LLM Model"
                )
                load_btn = gr.Button("Load Model", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### Chat Interface")
            chatbot = gr.Chatbot(label="Conversation", height=400)
            with gr.Row():
                message_input = gr.Textbox(
                    label="Message",
                    placeholder="Type your message...",
                    scale=4
                )
                send_btn = gr.Button("Send", variant="secondary", scale=1)
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=500,
                    value=150,
                    step=10,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            clear_btn = gr.Button("Clear Chat", variant="stop")

            # Wire up events
            load_btn.click(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[status]
            )
            send_btn.click(
                fn=self.generate_response,
                inputs=[message_input, max_tokens, temperature],
                outputs=[message_input, chatbot]
            )
            message_input.submit(
                fn=self.generate_response,
                inputs=[message_input, max_tokens, temperature],
                outputs=[message_input, chatbot]
            )
            clear_btn.click(
                fn=self.clear_history,
                outputs=[chatbot, message_input]
            )
        return interface
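
# Minimal launch sketch (assumption: this module is meant to be embedded in a
# parent gr.Blocks app; the `demo` name and standalone entry point below are
# illustrative, not part of the source):
if __name__ == "__main__":
    module = LLMModule()
    with gr.Blocks(title="LLM Testing") as demo:
        module.create_interface()
    demo.launch()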