Spaces:
Sleeping
Sleeping
"""
app.py — Gradio web demo for MT-LNN (Hugging Face Spaces).

Loads a base causal-LM from the Hub (default: Qwen2.5-0.5B-Instruct, supports
Chinese + English) and optionally applies a saved MT-LNN adapter checkpoint.
On CPU (e.g. free-CPU Spaces) the model runs in fp32; on a GPU it uses
bfloat16 when the device supports it and fp32 otherwise.

Environment variables (set in Space Settings → Variables):
    BASE_MODEL    HF model-id to load (default: Qwen/Qwen2.5-0.5B-Instruct)
    ADAPTER_PATH  local file path to an MT-LNN adapter .pt checkpoint
                  (optional; the file must exist on the Space's filesystem —
                  Hub paths are not downloaded automatically)
"""
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Both settings come from the Space's environment variables.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
ADAPTER_PATH = os.environ.get("ADAPTER_PATH", "")

# Default to fp32 on CPU; upgrade to CUDA, and to bfloat16 only when the GPU
# reports bf16 support.
DEVICE = "cpu"
DTYPE = torch.float32
if torch.cuda.is_available():
    DEVICE = "cuda"
    if torch.cuda.is_bf16_supported():
        DTYPE = torch.bfloat16
# ---------------------------------------------------------------------------
# Model loading (once at startup)
# ---------------------------------------------------------------------------
print(f"[MT-LNN] Loading {BASE_MODEL} on {DEVICE} ({DTYPE}) …")

_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if _tokenizer.pad_token is None:
    # No dedicated pad token defined: fall back to EOS.
    _tokenizer.pad_token = _tokenizer.eos_token

# On CUDA, transformers places the weights itself (device_map="auto"); on CPU
# no device map is used and the weights load onto the CPU.
# NOTE(review): older transformers releases spell the dtype keyword
# ``torch_dtype`` — confirm the pinned version accepts ``dtype``.
_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    low_cpu_mem_usage=True,
    device_map=None if DEVICE != "cuda" else "auto",
    dtype=DTYPE,
)
if ADAPTER_PATH and not os.path.isfile(ADAPTER_PATH):
    # A set-but-missing path used to be ignored silently. Only local files are
    # supported here (nothing is downloaded from the Hub), so report it.
    print(
        f"[MT-LNN] WARNING: ADAPTER_PATH={ADAPTER_PATH!r} is not a local file "
        "— running the vanilla base model"
    )
elif ADAPTER_PATH:
    try:
        # Project-local package, imported lazily so the demo still starts
        # when no adapter is configured.
        from mt_lnn.llama_adapter import attach_adapters_from_checkpoint, load_adapter_state

        checkpoint = torch.load(ADAPTER_PATH, map_location="cpu")
        attach_adapters_from_checkpoint(_model, checkpoint)
        # NOTE(review): load_adapter_state is given the *path*, so it presumably
        # reads the file a second time — confirm whether it could take the
        # already-loaded ``checkpoint`` instead.
        load_adapter_state(_model, ADAPTER_PATH, strict=False)
        del checkpoint  # drop this module-level reference to the loaded checkpoint
        print(f"[MT-LNN] Adapter loaded from {ADAPTER_PATH}")
    except Exception as exc:
        # Best-effort: keep serving the base model rather than crash the Space.
        print(f"[MT-LNN] WARNING: could not load adapter — {exc}")
        # The UI badge derives "adapter active" from ADAPTER_PATH; clear it so a
        # failed load is not advertised as an active adapter.
        # NOTE(review): if attach succeeded but load_adapter_state failed, the
        # model may still carry adapter modules without restored weights.
        # NOTE(review): on CUDA (device_map="auto") confirm attached adapter
        # modules end up on the same device as the base weights.
        ADAPTER_PATH = ""
# Only the CPU path needs an explicit move: there from_pretrained ran with
# device_map=None, while on CUDA device_map="auto" already placed the weights.
_model = _model.to(DEVICE) if DEVICE == "cpu" else _model
_model.eval()  # inference behaviour for layers such as dropout
print("[MT-LNN] Model ready.")
# ---------------------------------------------------------------------------
# Sampling helpers
# ---------------------------------------------------------------------------
def _top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the *k* largest logits along the last dim; mask the rest to -inf.

    Works for any number of leading (batch) dimensions, including 1-D logits.
    ``k <= 0`` disables the filter and ``k`` larger than the vocabulary keeps
    everything. Logits tied with the k-th largest value are all kept.
    """
    if k <= 0:
        return logits
    k = min(k, logits.size(-1))
    # ``[..., -1:]`` keeps a trailing singleton dim so the threshold broadcasts
    # against logits of any rank (the old ``[:, [-1]]`` required exactly 2-D).
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, float("-inf"))
def _top_p(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering along the last dim.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability reaches *p* — the token that crosses the threshold is kept —
    and masks every other logit to -inf. The most likely token is always kept,
    so ``p <= 0`` leaves a single token; ``p >= 1`` returns *logits* unchanged.
    """
    if p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = F.softmax(sorted_logits, dim=-1)
    # Probability mass strictly *before* each token: keep a token while the
    # nucleus built so far is still short of p. The inclusive test
    # ``cumsum <= p`` would drop the crossing token and keep less than p mass.
    mass_before = probs.cumsum(dim=-1) - probs
    keep = mass_before < p
    keep[..., 0] = True
    mask = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, sorted_idx, keep)
    return logits.masked_fill(~mask, float("-inf"))
def _content_to_text(content) -> str:
    """Flatten a Gradio chat-message ``content`` value into plain text.

    Accepts a plain string, ``None`` (a pending bot reply in tuple-format
    history), a dict with a ``"text"`` entry (a content part or a multimodal
    message), or a list/tuple of such parts. Dict parts without a string
    ``"text"`` entry (e.g. file parts) contribute nothing.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        pieces = (_content_to_text(part) for part in content)
        return "\n".join(piece for piece in pieces if piece)
    return str(content)


def _build_prompt(history: list, message: str, tokenizer=None) -> str:
    """Build a chat prompt, using ``apply_chat_template`` when available.

    Args:
        history: Previous turns from ``gr.ChatInterface``, either in messages
            format (``[{"role": ..., "content": ...}, ...]``, the only format
            in newer Gradio) or in legacy tuples format
            (``[[user_msg, bot_msg], ...]``). ``None`` is treated as empty.
        message: The new user message (a string, or a multimodal dict with a
            ``"text"`` entry).
        tokenizer: Tokenizer used for the chat template; defaults to the
            module-level ``_tokenizer``.

    Returns:
        The prompt text, ending with the assistant generation prompt.
    """
    tok = _tokenizer if tokenizer is None else tokenizer

    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Messages format. Unpacking such a dict as a pair (the old code)
            # yielded its *keys* or raised, so handle it explicitly.
            role = turn.get("role")
            if role not in ("system", "user", "assistant"):
                continue
            messages.append({"role": role, "content": _content_to_text(turn.get("content"))})
        else:
            # Tuples format; bot_msg is None while a reply is still pending.
            user_msg, bot_msg = turn
            if user_msg is not None:
                messages.append({"role": "user", "content": _content_to_text(user_msg)})
            if bot_msg is not None:
                messages.append({"role": "assistant", "content": _content_to_text(bot_msg)})
    messages.append({"role": "user", "content": _content_to_text(message)})

    if getattr(tok, "chat_template", None):
        return tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    # Fallback for models without a chat template
    prompt = "".join(f"<|{m['role']}|>\n{m['content']}\n" for m in messages)
    return prompt + "<|assistant|>\n"
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def _stop_token_ids() -> set:
    """Collect every token id that should end generation.

    Merges the tokenizer's ``eos_token_id`` with the model's
    ``generation_config.eos_token_id`` (an int, a list of ints, or ``None``);
    chat models may list extra end-of-turn ids only in the latter.
    """
    stop_ids = set()
    generation_config = getattr(_model, "generation_config", None)
    for candidate in (
        _tokenizer.eos_token_id,
        getattr(generation_config, "eos_token_id", None),
    ):
        if candidate is None:
            continue
        if isinstance(candidate, int):
            stop_ids.add(candidate)
        else:
            stop_ids.update(int(token_id) for token_id in candidate)
    return stop_ids


def _sample_tokens(input_ids, max_new_tokens, temperature, top_k, top_p):
    """Autoregressively sample up to *max_new_tokens* tokens after *input_ids*.

    *input_ids* is the (1, prompt_len) id tensor returned by the tokenizer.
    After every step this yields the 1-D tensor of all newly generated ids so
    far (including a final stop token, if one is sampled), so callers can
    stream partial output or keep only the last value.

    NOTE(review): the full sequence is re-run each step (no KV cache),
    presumably so adapters that process the whole sequence see every position;
    confirm adapter support before switching to ``past_key_values``.
    """
    stop_ids = _stop_token_ids()
    prompt_len = input_ids.shape[1]
    temperature = max(float(temperature), 1e-6)
    top_k, top_p = int(top_k), float(top_p)
    generated = input_ids
    for _ in range(int(max_new_tokens)):
        # Grad mode is thread-local and Gradio may resume this generator on a
        # different worker thread, so no_grad wraps only the forward pass and
        # is never held open across a ``yield``.
        with torch.no_grad():
            logits = _model(input_ids=generated).logits[:, -1, :]
        # Sample in fp32: a softmax/cumsum over the whole vocabulary in bf16
        # is too coarse for reliable top-p filtering.
        logits = logits.float() / temperature
        logits = _top_p(_top_k(logits, top_k), top_p)
        next_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        generated = torch.cat([generated, next_id], dim=1)
        yield generated[0, prompt_len:]
        if next_id.item() in stop_ids:
            return


def generate_text(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> str:
    """Sample a plain-text continuation of *prompt* (Completion tab).

    Args:
        prompt: Raw prompt text; no chat template is applied.
        max_new_tokens: Upper bound on sampled tokens.
        temperature: Logit temperature (clamped to at least 1e-6).
        top_k: Top-k filter size; 0 disables it.
        top_p: Nucleus threshold; 1.0 disables it.

    Returns:
        Only the newly generated text (special tokens skipped); ``""`` when the
        prompt tokenizes to no ids or ``max_new_tokens`` is 0.
    """
    ids = _tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
    if ids.shape[1] == 0:
        # e.g. an empty prompt with a tokenizer that adds no BOS: there is no
        # last position to sample from.
        return ""
    new_tokens = ids[0, :0]
    for new_tokens in _sample_tokens(ids, max_new_tokens, temperature, top_k, top_p):
        pass  # only the final, complete sequence is needed
    # Decode only the newly generated tokens to avoid space/encoding issues
    return _tokenizer.decode(new_tokens, skip_special_tokens=True)


def chat_stream(
    message: str,
    history: list,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
):
    """Stream a chat reply for ``gr.ChatInterface``.

    Yields the full reply decoded so far after every sampled token, as
    ChatInterface expects from a streaming generator.
    """
    prompt = _build_prompt(history, message)
    # A rendered chat template already contains the model's special tokens;
    # don't let the tokenizer prepend another BOS on top of it.
    templated = bool(getattr(_tokenizer, "chat_template", None))
    ids = _tokenizer(
        prompt, return_tensors="pt", add_special_tokens=not templated
    ).input_ids.to(DEVICE)
    for new_tokens in _sample_tokens(ids, max_new_tokens, temperature, top_k, top_p):
        # Decode ALL new tokens together — fixes SentencePiece space-prefix loss
        yield _tokenizer.decode(new_tokens, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Header badge: shows the adapter file name when ADAPTER_PATH points at an
# existing file, otherwise a "vanilla base model" notice.
if ADAPTER_PATH and os.path.isfile(ADAPTER_PATH):
    _adapter_badge = f"🧠 **MT-LNN adapter active** (`{os.path.basename(ADAPTER_PATH)}`)"
else:
    _adapter_badge = "⚙️ Running vanilla base model (no MT-LNN adapter)"

# Markdown shown above the tabs; assembled line by line (the leading and
# trailing "" entries give the surrounding newlines).
_description = "\n".join([
    "",
    "## MT-LNN — Microtubule Linear Neural Network",
    f"**Base model:** `{BASE_MODEL}` | **Device:** `{DEVICE}`",
    _adapter_badge,
    "This demo showcases the [MT-LNN architecture](https://huggingface.co/EverestAn/MT-LNN):",
    "a biologically-inspired hybrid that couples a standard transformer with a linear",
    "recurrent network modelling microtubule quantum-coherence dynamics.",
    "支持中英文对话 · Bilingual (Chinese & English) · Type below and hit **Submit**.",
    "",
])
def _sampling_controls():
    """Create the four sampling sliders (max tokens, temperature, top-k, top-p).

    Called inside a layout context, so the sliders are created there in this
    fixed order; returns them as a list.
    """
    return [
        gr.Slider(32, 512, value=200, step=32, label="Max new tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(0, 100, value=0, step=1, label="Top-k (0 = off)"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
    ]


with gr.Blocks(title="MT-LNN Demo") as demo:
    gr.Markdown(_description)

    # Streaming chat; the sliders are passed to chat_stream after (message, history).
    with gr.Tab("💬 Chat"):
        gr.ChatInterface(fn=chat_stream, additional_inputs=_sampling_controls())

    # One-shot raw-text completion without a chat template.
    with gr.Tab("📝 Completion"):
        prompt_box = gr.Textbox(
            lines=5, placeholder="Enter a prompt… / 输入提示词…", label="Prompt"
        )
        with gr.Row():
            max_tok, temp, top_k_sl, top_p_sl = _sampling_controls()
        run_btn = gr.Button("Generate", variant="primary")
        output_box = gr.Textbox(lines=10, label="Generated text", interactive=False)
        run_btn.click(
            fn=generate_text,
            inputs=[prompt_box, max_tok, temp, top_k_sl, top_p_sl],
            outputs=output_box,
        )

    gr.Markdown(
        "---\n"
        "Model weights & code: [EverestAn/MT-LNN](https://huggingface.co/EverestAn/MT-LNN) · "
        "MIT license"
    )
# Serve on all interfaces at port 7860 (the default Hugging Face Spaces port),
# with server-side rendering disabled.
# NOTE(review): ``theme`` is accepted by launch() only in Gradio versions that
# moved it there from gr.Blocks — confirm against the pinned gradio version.
demo.launch(
    theme=gr.themes.Soft(),
    ssr_mode=False,
    server_port=7860,
    server_name="0.0.0.0",
)