import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Model configuration
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Respond briefly."


# Load model and tokenizer
def load_model():
    """Load the model and tokenizer."""
    try:
        print(f"Loading model: {MODEL_ID}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        print("Model loaded successfully!")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


# Initialize model and tokenizer at import time so the first request is fast
try:
    model, tokenizer = load_model()
except Exception as e:
    print(f"Failed to load model: {e}")
    model = None
    tokenizer = None


@spaces.GPU
def chat_response(message, history):
    """
    Generate a response for the chat interface.

    Args:
        message (str): Current user message.
        history (list): Chat history as a list of tuples
            [(user_msg, assistant_msg), ...] (Gradio 4.x tuple format).

    Returns:
        str: Generated response.
    """
    if model is None or tokenizer is None:
        return "Model not loaded. Please check the model configuration."

    try:
        # Build the conversation in chat-message format
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        # Add chat history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})

        # Add the current message
        messages.append({"role": "user", "content": message})

        # Apply the model's chat template
        formatted_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize the input and move it to the model's device
        model_inputs = tokenizer([formatted_input], return_tensors="pt").to(model.device)

        # Generate the response
        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )

        # Strip the prompt tokens, keeping only the newly generated ones
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response.strip()

    except Exception as e:
        print(f"Error generating response: {e}")
        return f"Sorry, I encountered an error: {str(e)}"


def create_demo():
    """Create the Gradio chat interface."""
    demo = gr.ChatInterface(
        fn=chat_response,
        title="VibeThinker-1.5B Chat",
        description=f"Chat with {MODEL_ID}. {SYSTEM_PROMPT}",
        examples=[
            "What is 2+2?",
            "Explain quantum physics briefly",
            "Write a short poem",
            "How do I make good decisions?"
        ],
        theme=gr.themes.Soft(),
        show_progress="minimal",
        # Note: the button-label parameters below exist in Gradio 4.x but
        # were removed in Gradio 5; drop them if upgrading.
        retry_btn="🔄 Retry",
        undo_btn="↩️ Undo",
        clear_btn="🗑️ Clear",
    )
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=False)
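
# --- Usage note (a minimal sketch, not part of the app) ---
# Assuming the file is saved as app.py, one way to run it locally is:
#
#   pip install gradio torch transformers accelerate spaces
#   python app.py
#
# `accelerate` is assumed here because device_map="auto" requires it, and
# the `spaces` package documents its @spaces.GPU decorator as a no-op
# outside of a Hugging Face ZeroGPU Space.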