"""Gradio demo for the MobileLLM-60M base model (CPU-only inference)."""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# ---------------------------------------------------------
# Updated Model ID for the 60M version
MODEL_ID = "Yangyang1205/MobileLLM-60M"
# ---------------------------------------------------------

# Module-level load status, read by generate_text() to report failures in the UI.
model_loaded = False
load_error_msg = ""


def load_model():
    """Load the tokenizer and model onto CPU.

    Returns:
        (tokenizer, model) on success, or (None, None) on failure.
        On failure the error text is recorded in the module-level
        ``load_error_msg`` so the UI can surface it.
    """
    global model_loaded, load_error_msg
    print(f"🚀 Starting... Loading model: {MODEL_ID}")
    try:
        # 1. Force Config Fix (Tie Weights)
        # Essential for MobileLLM architecture to prevent output gibberish
        config = AutoConfig.from_pretrained(MODEL_ID)
        config.tie_word_embeddings = True

        # 2. Load Model
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=config,
            use_safetensors=False,  # Kept as requested
            trust_remote_code=True,
        )
        model = model.to("cpu")
        model.eval()
        model_loaded = True
        return tokenizer, model
    except Exception as e:
        model_loaded = False
        load_error_msg = str(e)
        print(f"Error loading model: {e}")
        return None, None


tokenizer, model = load_model()


# --- Core Generation Function ---
def generate_text(prompt, max_len, temp):
    """Generate a deterministic continuation of ``prompt`` via beam search.

    Args:
        prompt: Input text to continue.
        max_len: Maximum number of new tokens to generate.
        temp: Accepted for UI/signature compatibility but intentionally
            unused — temperature has no effect when ``do_sample=False``.

    Returns:
        The generated continuation (first non-empty line when the output
        spans multiple lines), or an error message string on failure.
    """
    if not model_loaded:
        return f"Model not loaded: {load_error_msg}"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_len,
                do_sample=False,  # must be False for fully deterministic beam search
                num_beams=3,  # beam search over 3 candidate paths
                repetition_penalty=1.0,
                pad_token_id=tokenizer.eos_token_id,
                # temperature omitted: it has no effect when do_sample=False
            )

        # Decode only the newly generated token ids. Slicing the token
        # sequence is more robust than slicing the decoded string by
        # len(prompt), because decoding may not reproduce the prompt text
        # byte-for-byte (tokenizer normalization).
        prompt_len = inputs["input_ids"].shape[1]
        new_text = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )

        # Heuristic: Cut off incomplete sentences (stop at the first newline if present)
        if "\n" in new_text.strip():
            lines = [line for line in new_text.split("\n") if line.strip()]
            if lines:
                return lines[0]
        return new_text
    except Exception as e:
        return str(e)


# --- UI Layout (Blocks) ---
# Note: No 'theme' or 'title' arguments in Blocks() to prevent Gradio version errors
with gr.Blocks() as demo:
    # Header Section
    gr.Markdown(
        """
        # 📱 MobileLLM 60M Demo
        This is a **60M parameter** base model trained on <1% of FineWeb data.
        It is a base model (not a chat model) and excels at **In-Context Learning**.
        """
    )

    # Split Layout
    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_box = gr.Textbox(
                label="Input Prompt",
                lines=10,
                placeholder="Enter text patterns here...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is",
            )
            # Advanced Settings
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                slider_len = gr.Slider(
                    minimum=1, maximum=100, value=20, label="Max Length", step=1
                )
                slider_temp = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.6, label="Temperature", step=0.1
                )
            submit_btn = gr.Button("🚀 Generate", variant="primary")

        # Right Column: Output
        with gr.Column():
            output_box = gr.Textbox(
                label="Model Output",
                lines=10,
                interactive=False,
            )

    # Event Binding
    submit_btn.click(
        fn=generate_text,
        inputs=[input_box, slider_len, slider_temp],
        outputs=output_box,
    )

    # Examples
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n    return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box,
    )

if __name__ == "__main__":
    demo.launch()