File size: 4,710 Bytes
66959df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e311af
 
 
 
 
 
66959df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# ---------------------------------------------------------
# Updated Model ID for the 60M version
MODEL_ID = "Yangyang1205/MobileLLM-60M" 
# ---------------------------------------------------------

# Load-status flags shared between load_model() (which mutates them via
# `global`) and generate_text() (which reads them to report failures in
# the UI instead of crashing).
model_loaded = False
load_error_msg = ""

def load_model():
    """Download and prepare the tokenizer/model pair for CPU inference.

    Returns:
        (tokenizer, model) on success, or (None, None) on failure.
        Success/failure is also recorded in the module-level flags
        ``model_loaded`` / ``load_error_msg`` for the UI to report.
    """
    global model_loaded, load_error_msg
    print(f"🚀 Starting... Loading model: {MODEL_ID}")
    try:
        # Force tied input/output embeddings in the config. Essential for
        # the MobileLLM architecture to prevent gibberish output.
        cfg = AutoConfig.from_pretrained(MODEL_ID)
        cfg.tie_word_embeddings = True

        tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        net = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            use_safetensors=False,  # kept as requested
            trust_remote_code=True,
        ).to("cpu")
        net.eval()
    except Exception as e:
        # Best-effort: record the error so generate_text() can surface it
        # in the output box rather than crashing the app at import time.
        model_loaded = False
        load_error_msg = str(e)
        print(f"Error loading model: {e}")
        return None, None
    model_loaded = True
    return tok, net

# Load once at import time; both are None (and model_loaded is False)
# if loading failed — generate_text() guards on that flag.
tokenizer, model = load_model()

# --- Core Generation Function ---
def generate_text(prompt, max_len, temp):
    """Deterministically continue `prompt` with beam search on CPU.

    Args:
        prompt: Input text to continue.
        max_len: Maximum number of new tokens to generate (slider value;
            may arrive as a float, so it is coerced to int below).
        temp: Temperature slider value. Currently unused because decoding
            is deterministic (do_sample=False); kept so the UI wiring and
            the function signature stay unchanged.

    Returns:
        The generated continuation — trimmed to the first non-empty line
        when the model emits several — or an error message string.
    """
    if not model_loaded:
        return f"Model not loaded: {load_error_msg}"

    try:
        inputs = tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_len),  # Gradio sliders can pass floats
                do_sample=False,        # must be False to keep beam search deterministic
                num_beams=3,            # beam search over 3 candidate paths
                repetition_penalty=1.0,
                pad_token_id=tokenizer.eos_token_id
                # NOTE: temperature is intentionally omitted — it has no
                # effect when do_sample=False.
            )

        # Decode the full sequence (prompt + continuation).
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Strip the input prompt from the front of the decoded text.
        # NOTE(review): assumes decode() reproduces the prompt verbatim;
        # a tokenizer that normalizes whitespace could shift this cut.
        new_text = full_response[len(prompt):]

        # Heuristic: cut off incomplete trailing sentences by returning
        # only the first non-empty line when several lines were produced.
        if "\n" in new_text.strip():
            lines = [line for line in new_text.split('\n') if line.strip()]
            if lines:
                return lines[0]

        return new_text

    except Exception as e:
        # Surface the error text in the output box rather than crashing the UI.
        return str(e)

# --- UI Layout (Blocks) ---
# Note: No 'theme' or 'title' arguments in Blocks() to prevent Gradio version errors
with gr.Blocks() as demo:
    
    # Header Section: static description of the model.
    gr.Markdown(
        """
        # 📱 MobileLLM 60M Demo
        This is a **60M parameter** base model trained on <1% of FineWeb data.
        It is a base model (not a chat model) and excels at **In-Context Learning**.
        """
    )
    
    # Split Layout: prompt + controls on the left, output on the right.
    with gr.Row():
        # Left Column: Input
        with gr.Column():
            # Pre-filled with a few-shot pattern to showcase in-context learning.
            input_box = gr.Textbox(
                label="Input Prompt",
                lines=10, 
                placeholder="Enter text patterns here...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"
            )
            
            # Advanced Settings (collapsed by default).
            # NOTE: the Temperature slider is wired through to generate_text()
            # but has no effect there while decoding stays deterministic.
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                slider_len = gr.Slider(minimum=1, maximum=100, value=20, label="Max Length", step=1)
                slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="Temperature", step=0.1)
            
            submit_btn = gr.Button("🚀 Generate", variant="primary")

        # Right Column: Output (read-only textbox).
        with gr.Column():
            output_box = gr.Textbox(
                label="Model Output",
                lines=10, 
                interactive=False 
            )

    # Event Binding: button click runs generation with the three inputs.
    submit_btn.click(
        fn=generate_text, 
        inputs=[input_box, slider_len, slider_temp], 
        outputs=output_box
    )

    # Examples: one-click prompts (few-shot pattern, free text, code completion).
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n    return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box
    )

if __name__ == "__main__":
    demo.launch()