import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# ---------------------------------------------------------
# Updated Model ID for the 60M version
MODEL_ID = "Yangyang1205/MobileLLM-60M"
# ---------------------------------------------------------
# Module-level load status, written by load_model() and read by
# generate_text() so the UI can report a friendly error message
# instead of crashing when the model fails to load.
model_loaded = False
load_error_msg = ""
def load_model():
    """Load the MobileLLM tokenizer and model onto CPU.

    Updates the module-level flags ``model_loaded`` / ``load_error_msg``
    so the UI can surface a readable error instead of crashing.

    Returns:
        (tokenizer, model) on success, or (None, None) on failure.
    """
    global model_loaded, load_error_msg
    print(f"🚀 Starting... Loading model: {MODEL_ID}")
    try:
        # Force tied input/output embeddings in the config first —
        # the MobileLLM architecture emits gibberish without this.
        cfg = AutoConfig.from_pretrained(MODEL_ID)
        cfg.tie_word_embeddings = True

        # Slow tokenizer is used deliberately (use_fast=False).
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            use_safetensors=False,  # Kept as requested
            trust_remote_code=True,
        )

        model = model.to("cpu")
        model.eval()  # inference mode only
        model_loaded = True
        return tokenizer, model
    except Exception as e:
        # Record the failure so generate_text() can report it.
        model_loaded = False
        load_error_msg = str(e)
        print(f"Error loading model: {e}")
        return None, None


# Load once at import time; the UI handles the (None, None) case.
tokenizer, model = load_model()
# --- Core Generation Function ---
def generate_text(prompt, max_len, temp):
    """Continue ``prompt`` with the loaded model using beam search.

    Args:
        prompt: input text for the base model to complete.
        max_len: maximum number of new tokens to generate.
        temp: accepted for UI compatibility but intentionally UNUSED —
            with ``do_sample=False`` temperature has no effect.

    Returns:
        The generated continuation (first non-empty line when the
        output spans multiple lines), or an error message string.
    """
    if not model_loaded:
        return f"Model not loaded: {load_error_msg}"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_len,
                do_sample=False,  # must be False for fully deterministic beam search
                num_beams=3,  # beam search over 3 candidate paths
                repetition_penalty=1.0,
                pad_token_id=tokenizer.eos_token_id,
                # temperature deliberately omitted: ignored when do_sample=False
            )

        # Decode the whole sequence, then strip the echoed prompt prefix.
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        continuation = decoded[len(prompt):]

        # Heuristic: when the continuation spans several lines, return
        # only the first non-empty one to cut off incomplete sentences.
        if "\n" in continuation.strip():
            segments = [seg for seg in continuation.split("\n") if seg.strip()]
            if segments:
                return segments[0]
        return continuation
    except Exception as e:
        return str(e)
# --- UI Layout (Blocks) ---
# gr.Blocks() is called with no 'theme'/'title' kwargs on purpose, for
# compatibility across Gradio versions.
with gr.Blocks() as demo:
    # Header section.
    gr.Markdown(
        """
        # 📱 MobileLLM 60M Demo
        This is a **60M parameter** base model trained on <1% of FineWeb data.
        It is a base model (not a chat model) and excels at **In-Context Learning**.
        """
    )

    # Two-column layout: input controls on the left, output on the right.
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(
                label="Input Prompt",
                lines=10,
                placeholder="Enter text patterns here...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is",
            )
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                slider_len = gr.Slider(minimum=1, maximum=100, value=20, label="Max Length", step=1)
                slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="Temperature", step=0.1)
            submit_btn = gr.Button("🚀 Generate", variant="primary")

        with gr.Column():
            output_box = gr.Textbox(
                label="Model Output",
                lines=10,
                interactive=False,
            )

    # Wire the button to the generation function.
    submit_btn.click(
        fn=generate_text,
        inputs=[input_box, slider_len, slider_temp],
        outputs=output_box,
    )

    # Clickable example prompts demonstrating in-context learning.
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box,
    )

if __name__ == "__main__":
    demo.launch()