# MobileLLM-60M / app.py
# Hugging Face Space by Yangyang1205 (commit 3e311af, verified)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# ---------------------------------------------------------
# Updated Model ID for the 60M version
MODEL_ID = "Yangyang1205/MobileLLM-60M"
# ---------------------------------------------------------
# Load status flag: set by load_model(), checked by generate_text().
model_loaded = False
# Human-readable reason for a failed load; surfaced to the user in the UI.
load_error_msg = ""
def load_model():
    """Load the MobileLLM tokenizer and model onto the CPU.

    Updates the module-level ``model_loaded`` / ``load_error_msg`` flags so
    the UI handler can report failures.

    Returns:
        (tokenizer, model) on success, or (None, None) when loading fails.
    """
    global model_loaded, load_error_msg
    print(f"🚀 Starting... Loading model: {MODEL_ID}")
    try:
        # Explicitly tie input/output embeddings in the config; the
        # MobileLLM architecture produces gibberish without this.
        cfg = AutoConfig.from_pretrained(MODEL_ID)
        cfg.tie_word_embeddings = True

        # Slow tokenizer + .bin (non-safetensors) checkpoint, kept as-is
        # to match how this repository was published.
        tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        net = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            use_safetensors=False,
            trust_remote_code=True,
        )
        net = net.to("cpu")
        net.eval()  # inference only; disables dropout etc.

        model_loaded = True
        return tok, net
    except Exception as e:
        # Record the failure so generate_text() can show it instead of crashing.
        model_loaded = False
        load_error_msg = str(e)
        print(f"Error loading model: {e}")
        return None, None
# Load once at import time; the Gradio handler below reads these globals.
tokenizer, model = load_model()
# --- Core Generation Function ---
def generate_text(prompt, max_len, temp):
    """Generate a deterministic continuation of ``prompt`` via beam search.

    Args:
        prompt: Input text for the base model to continue.
        max_len: Maximum number of new tokens to generate.
        temp: Temperature slider value; currently unused because decoding
            is deterministic (``do_sample=False`` with beam search). Kept
            so the UI wiring stays unchanged.

    Returns:
        The generated continuation (trimmed to the first non-empty line
        when the model emits several lines), or an error message string.
    """
    if not model_loaded:
        return f"Model not loaded: {load_error_msg}"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_len,
                do_sample=False,   # must be False for fully deterministic beam search
                num_beams=3,       # beam search over 3 candidate paths
                repetition_penalty=1.0,
                pad_token_id=tokenizer.eos_token_id,
                # temperature intentionally omitted: it has no effect when
                # do_sample=False
            )
        # Strip the prompt by token count, not by string length:
        # tokenizer.decode() does not always reproduce the input text
        # byte-for-byte (e.g. whitespace normalization with slow
        # tokenizers), so slicing the decoded string by len(prompt)
        # can cut or keep the wrong characters.
        prompt_token_count = inputs["input_ids"].shape[1]
        new_text = tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        )
        # Heuristic: a base model keeps rambling, so when the continuation
        # spans multiple lines keep only the first non-empty one.
        if "\n" in new_text.strip():
            lines = [line for line in new_text.split('\n') if line.strip()]
            if lines:
                return lines[0]
        return new_text
    except Exception as e:
        # Surface the error text in the output box instead of crashing the app.
        return str(e)
# --- UI Layout (Blocks) ---
# Note: No 'theme' or 'title' arguments in Blocks() to prevent Gradio version errors
with gr.Blocks() as demo:
    # Header Section: static model description shown above the demo.
    gr.Markdown(
        """
# 📱 MobileLLM 60M Demo
This is a **60M parameter** base model trained on <1% of FineWeb data.
It is a base model (not a chat model) and excels at **In-Context Learning**.
"""
    )
    # Split Layout: prompt + settings on the left, output on the right.
    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_box = gr.Textbox(
                label="Input Prompt",
                lines=10,
                placeholder="Enter text patterns here...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"
            )
            # Advanced Settings (collapsed by default).
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                slider_len = gr.Slider(minimum=1, maximum=100, value=20, label="Max Length", step=1)
                # NOTE(review): this value is forwarded to generate_text but
                # decoding there is deterministic (do_sample=False), so the
                # temperature slider currently has no effect — confirm intent.
                slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="Temperature", step=0.1)
            submit_btn = gr.Button("🚀 Generate", variant="primary")
        # Right Column: read-only output.
        with gr.Column():
            output_box = gr.Textbox(
                label="Model Output",
                lines=10,
                interactive=False
            )
    # Event Binding: wire the button to the generation function.
    submit_btn.click(
        fn=generate_text,
        inputs=[input_box, slider_len, slider_temp],
        outputs=output_box
    )
    # Examples: clickable prompts that populate the input box.
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box
    )
# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()