# Hugging Face Space: MobileLLM-60M text-generation demo
# (Hub page status lines "Spaces: / Sleeping" stripped from the paste.)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# ---------------------------------------------------------
# Updated Model ID for the 60M version
MODEL_ID = "Yangyang1205/MobileLLM-60M"
# ---------------------------------------------------------
# Module-level load status: written by load_model() at import time and
# read by generate_text() so the UI can surface load failures.
model_loaded = False
load_error_msg = ""
def load_model():
    """Load the MobileLLM tokenizer and model onto CPU.

    Sets the module-level flags ``model_loaded`` / ``load_error_msg`` so the
    UI can report a failed startup instead of crashing.

    Returns:
        tuple: ``(tokenizer, model)`` on success, ``(None, None)`` on failure.
    """
    global model_loaded, load_error_msg
    print(f"🚀 Starting... Loading model: {MODEL_ID}")
    try:
        # 1. Force Config Fix (Tie Weights)
        # Essential for MobileLLM architecture to prevent output gibberish
        config = AutoConfig.from_pretrained(MODEL_ID)
        config.tie_word_embeddings = True
        # 2. Load Model
        # use_fast=False: slow (SentencePiece/Python) tokenizer kept on purpose.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=config,
            use_safetensors=False,  # Kept as requested
            trust_remote_code=True
        )
        # CPU-only Space: keep the model off any accelerator explicitly.
        model = model.to("cpu")
        model.eval()
        model_loaded = True
        return tokenizer, model
    except Exception as e:
        # Record the failure so generate_text() can show it in the output box.
        model_loaded = False
        load_error_msg = str(e)
        print(f"Error loading model: {e}")
        return None, None
# Load once at import time so the model is ready before the UI launches.
tokenizer, model = load_model()
| # --- Core Generation Function --- | |
# --- Core Generation Function ---
def generate_text(prompt, max_len, temp):
    """Generate a deterministic continuation of *prompt* with beam search.

    Args:
        prompt (str): Input text / few-shot pattern to continue.
        max_len (int): Maximum number of new tokens to generate.
        temp (float): Accepted for UI compatibility but intentionally unused —
            decoding is deterministic (``do_sample=False``), where temperature
            has no effect.

    Returns:
        str: The generated continuation (first non-empty line when the model
        emits multiple lines), or an error message on failure.
    """
    if not model_loaded:
        return f"Model not loaded: {load_error_msg}"
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        prompt_token_count = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_len,
                do_sample=False,       # must be False to fully exploit deterministic beam search
                num_beams=3,           # beam search over 3 candidate paths
                repetition_penalty=1.0,
                pad_token_id=tokenizer.eos_token_id
                # temperature is omitted: it has no effect when do_sample=False
            )
        # Decode ONLY the newly generated tokens. Slicing the decoded string
        # by len(prompt) is fragile: tokenize->decode does not always
        # round-trip the prompt exactly (whitespace/unicode normalization),
        # which can clip the output or leak prompt text into it.
        new_text = tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        )
        # Heuristic: cut off incomplete sentences by returning the first
        # non-empty line when the model produced multiple lines.
        if "\n" in new_text.strip():
            lines = [line for line in new_text.split('\n') if line.strip()]
            if lines:
                return lines[0]
        return new_text
    except Exception as e:
        # Surface the error text directly in the UI output box.
        return str(e)
| # --- UI Layout (Blocks) --- | |
| # Note: No 'theme' or 'title' arguments in Blocks() to prevent Gradio version errors | |
# --- UI Layout (Blocks) ---
# Note: No 'theme' or 'title' arguments in Blocks() to prevent Gradio version errors
# NOTE: statement order inside the context managers defines the rendered
# layout — do not reorder component creation.
with gr.Blocks() as demo:
    # Header Section
    gr.Markdown(
        """
        # 📱 MobileLLM 60M Demo
        This is a **60M parameter** base model trained on <1% of FineWeb data.
        It is a base model (not a chat model) and excels at **In-Context Learning**.
        """
    )
    # Split Layout
    with gr.Row():
        # Left Column: Input
        with gr.Column():
            input_box = gr.Textbox(
                label="Input Prompt",
                lines=10,
                placeholder="Enter text patterns here...",
                value="The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"
            )
            # Advanced Settings
            # (slider_temp is wired through but has no effect: generation is
            # deterministic — see generate_text.)
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                slider_len = gr.Slider(minimum=1, maximum=100, value=20, label="Max Length", step=1)
                slider_temp = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, label="Temperature", step=0.1)
            submit_btn = gr.Button("🚀 Generate", variant="primary")
        # Right Column: Output
        with gr.Column():
            output_box = gr.Textbox(
                label="Model Output",
                lines=10,
                interactive=False
            )
    # Event Binding: button click runs generation with the three inputs.
    submit_btn.click(
        fn=generate_text,
        inputs=[input_box, slider_len, slider_temp],
        outputs=output_box
    )
    # Examples: clicking one fills the input box (no auto-run).
    gr.Examples(
        examples=[
            ["The capital of China is Beijing.\nThe capital of Japan is Tokyo.\nThe capital of Germany is Berlin.\nThe capital of France is"],
            ["Artificial Intelligence is a field of computer science that"],
            ["def add(a, b):\n    return a + b\n\ndef multiply(a, b):"],
        ],
        inputs=input_box
    )

if __name__ == "__main__":
    demo.launch()