# Hugging Face Space demo app for LiManshu/nextShakespeare.
import sys
from pathlib import Path

import gradio as gr
import torch
import yaml
from huggingface_hub import snapshot_download

# Hugging Face Hub repo holding everything the demo needs: the `llm` package
# code, a model config, a character vocabulary, and the trained checkpoint.
MODEL_REPO_ID = "LiManshu/nextShakespeare"

# Run on GPU when available, otherwise CPU; the model is moved here once.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily initialized singletons, populated exactly once by _ensure_loaded().
_MODEL = None       # Transformer instance, eval mode, on DEVICE
_TOKENIZER = None   # character-level tokenizer built from the repo vocab
_MODEL_CFG = None   # small dict of config facts rendered in the UI info line
def _load_yaml(path: Path):
    """Parse the YAML file at *path* and return the resulting Python object."""
    with path.open(encoding="utf-8") as fh:
        return yaml.safe_load(fh)
def _ensure_loaded():
    """Download the model repo and initialize the module-level singletons.

    Populates _MODEL, _TOKENIZER, and _MODEL_CFG on the first call; later
    calls return immediately without re-downloading or rebuilding anything.
    """
    global _MODEL, _TOKENIZER, _MODEL_CFG
    # Already initialized — nothing to do.
    if _MODEL is not None and _TOKENIZER is not None:
        return
    # Fetch only the files the demo actually needs from the Hub snapshot.
    local_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        repo_type="model",
        allow_patterns=[
            "llm/**",
            "configs/model.yaml",
            "data/vocab/char_vocab.json",
            "checkpoints/best_model.pt",
        ],
    )
    repo_root = Path(local_dir)
    # The repo ships its own `llm` package; make it importable before the
    # imports below.
    sys.path.insert(0, str(repo_root))
    from llm.data.tokenizer import CharTokenizer
    from llm.inference.generate import greedy_decode, sample_decode
    from llm.model.transformer import Transformer
    from llm.utils.checkpoint import load_model_only
    model_cfg = _load_yaml(repo_root / "configs" / "model.yaml")
    tokenizer = CharTokenizer(vocab_path=str(repo_root / "data" / "vocab" / "char_vocab.json"))
    # Keep the config's vocab size in sync with the tokenizer actually shipped.
    model_cfg["vocab_size"] = tokenizer.vocab_size
    model = Transformer(model_cfg)
    load_model_only(model, str(repo_root / "checkpoints" / "best_model.pt"))
    model.to(DEVICE)
    model.eval()  # inference only
    _MODEL = model
    _TOKENIZER = tokenizer
    # Summary shown in the UI's "Model Info" output.
    _MODEL_CFG = {
        "hidden_size": model_cfg.get("hidden_size"),
        "num_hidden_layers": model_cfg.get("num_hidden_layers"),
        "vocab_size": tokenizer.vocab_size,
        "device": str(DEVICE),
    }
    # Attach the decoding functions to the model object so generate() can
    # reach them without importing from `llm` itself. They are plain
    # functions stored on the instance (not bound methods), which is why
    # generate() passes _MODEL explicitly as the first argument.
    _MODEL.greedy_decode = greedy_decode
    _MODEL.sample_decode = sample_decode
def generate(prompt, max_length, temperature, top_k, top_p):
    """Generate text continuing *prompt* and return (text, model-info line).

    A temperature of exactly 0 selects greedy decoding; any other value
    uses sampling, where a non-positive top_k / top_p disables that filter.
    """
    _ensure_loaded()
    # Encode the prompt; an empty encoding falls back to a single token 0.
    token_ids = _TOKENIZER.encode(prompt or "") or [0]
    batch = torch.tensor([token_ids], dtype=torch.long)
    temp = float(temperature)
    max_len = int(max_length)
    with torch.no_grad():
        if temp == 0.0:
            generated = _MODEL.greedy_decode(
                _MODEL,
                batch,
                max_length=max_len,
                device=DEVICE,
            )
        else:
            k = int(top_k)
            p = float(top_p)
            generated = _MODEL.sample_decode(
                _MODEL,
                batch,
                max_length=max_len,
                temperature=temp,
                top_k=k if k > 0 else None,
                top_p=p if p > 0 else None,
                device=DEVICE,
            )
    info = (
        f"repo={MODEL_REPO_ID} | device={_MODEL_CFG['device']} | "
        f"layers={_MODEL_CFG['num_hidden_layers']} | hidden={_MODEL_CFG['hidden_size']}"
    )
    return _TOKENIZER.decode(generated[0]), info
# Wire the generation function into a simple Gradio form:
# prompt plus decoding knobs in, generated text plus a model-info line out.
_inputs = [
    gr.Textbox(label="Prompt", value="First Citizen:\n", lines=8),
    gr.Slider(minimum=1, maximum=400, value=200, step=1, label="max_length"),
    gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.1, label="temperature"),
    gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="top_p"),
]
_outputs = [
    gr.Textbox(label="Generated Text", lines=14),
    gr.Textbox(label="Model Info"),
]
demo = gr.Interface(
    fn=generate,
    inputs=_inputs,
    outputs=_outputs,
    title="manshu-init",
    description="Online demo for LiManshu/nextShakespeare",
)
if __name__ == "__main__":
    # Start the Gradio server when the file is run as a script.
    demo.launch()