import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces  # Hugging Face ZeroGPU decorator package

# Model configuration
MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Respond briefly."

# Load model and tokenizer
def load_model():
    """Load the model and tokenizer"""
    try:
        print(f"Loading model: {MODEL_ID}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,  # half-precision weights for the 1.5B model
            device_map="auto",          # let accelerate place the weights (requires accelerate)
        )
        print("Model loaded successfully!")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Initialize model and tokenizer
try:
    model, tokenizer = load_model()
except Exception as e:
    print(f"Failed to load model: {e}")
    model = None
    tokenizer = None

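# ZeroGPU note: @spaces.GPU attaches a GPU to the Space only for the duration
# of each decorated call; a longer window can be requested with e.g.
# @spaces.GPU(duration=120). Outside Spaces the decorator is designed to be a
# no-op, so the same file also runs locally.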
@spaces.GPU
def chat_response(message, history):
    """
    Generate response for the chat interface.
    
    Args:
        message (str): Current user message
        history (list): Chat history as list of tuples [(user_msg, assistant_msg), ...]
    
    Returns:
        str: Generated response
    """
    if model is None or tokenizer is None:
        return "Model not loaded. Please check the model configuration."
    
    try:
        # Build conversation format
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        
        # Add chat history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        
        # Add current message
        messages.append({"role": "user", "content": message})
        
        # Apply chat template
        formatted_input = tokenizer.apply_chat_template(
            messages, 
            tokenize=False, 
            add_generation_prompt=True
        )
        
        # Tokenize input
        model_inputs = tokenizer([formatted_input], return_tensors="pt").to(model.device)
        
        # Generate response
        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id  # use EOS for padding if no pad token is defined
            )
        
        # Decode response
        generated_ids = [
            output_ids[len(input_ids):] 
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        return response.strip()
        
    except Exception as e:
        print(f"Error generating response: {e}")
        return f"Sorry, I encountered an error: {str(e)}"

def create_demo():
    """Create the Gradio chat interface"""
    
    # Create chat interface
    demo = gr.ChatInterface(
        fn=chat_response,
        title="VibeThinker-1.5B Chat",
        description=f"Chat with {MODEL_ID}. {SYSTEM_PROMPT}",
        examples=[
            "What is 2+2?",
            "Explain quantum physics briefly",
            "Write a short poem",
            "How do I make good decisions?"
        ],
        theme=gr.themes.Soft(),
        show_progress="minimal",
        retry_btn="🔄 Retry",
        undo_btn="↩️ Undo",
        clear_btn="🗑️ Clear",
    )
    
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=False)
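
# Local usage note (assumptions, not from the original Space config): outside
# Hugging Face Spaces the `spaces` decorator is designed to be a no-op, so
# `python app.py` should serve the UI at http://127.0.0.1:7860 on a machine
# with enough memory for the fp16 weights. The Space's requirements.txt would
# plausibly list: gradio, spaces, torch, transformers, and accelerate (the
# latter is needed for device_map="auto").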