Spaces: Running on Zero
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Model configuration
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Respond briefly."

# Load model and tokenizer
def load_model():
    """Load the model and tokenizer."""
    try:
        print(f"Loading model: {MODEL_ID}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        print("Model loaded successfully!")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Initialize model and tokenizer
try:
    model, tokenizer = load_model()
except Exception as e:
    print(f"Failed to load model: {e}")
    model = None
    tokenizer = None

@spaces.GPU  # ZeroGPU: attach a GPU for the duration of this call
def chat_response(message, history):
    """
    Generate a response for the chat interface.

    Args:
        message (str): Current user message
        history (list): Chat history as a list of (user_msg, assistant_msg) tuples

    Returns:
        str: Generated response
    """
    if model is None or tokenizer is None:
        return "Model not loaded. Please check the model configuration."
    try:
        # Build conversation in chat-message format
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        # Add chat history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})

        # Add current message
        messages.append({"role": "user", "content": message})

        # Apply chat template
        formatted_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize input
        model_inputs = tokenizer([formatted_input], return_tensors="pt").to(model.device)

        # Generate response
        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )

        # Strip the prompt tokens, then decode only the newly generated ones
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response.strip()
    except Exception as e:
        print(f"Error generating response: {e}")
        return f"Sorry, I encountered an error: {str(e)}"

def create_demo():
    """Create the Gradio chat interface."""
    demo = gr.ChatInterface(
        fn=chat_response,
        title="VibeThinker-1.5B Chat",
        description=f"Chat with {MODEL_ID}. {SYSTEM_PROMPT}",
        examples=[
            "What is 2+2?",
            "Explain quantum physics briefly",
            "Write a short poem",
            "How do I make good decisions?"
        ],
        theme=gr.themes.Soft(),
        show_progress="minimal",
        # Custom button labels are Gradio 4.x parameters (removed in Gradio 5)
        retry_btn="🔄 Retry",
        undo_btn="↩️ Undo",
        clear_btn="🗑️ Clear",
    )
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=False)
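
# A plausible requirements.txt for this Space (an assumption; the original
# listing does not include one):
#   gradio
#   torch
#   transformers
#   accelerate   # required for device_map="auto"
#   spaces       # ZeroGPU helper, preinstalled on Hugging Face Spaces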