import gradio as gr
import tarfile
import os
import shutil
import glob
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

UPLOAD_DIR = "uploaded_models"
ACTIVE_MODEL_DIR = "active_model"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Global model state
loaded_model = None
loaded_tokenizer = None
loaded_model_name = None

def list_saved_models():
    tars = glob.glob(os.path.join(UPLOAD_DIR, "*.tar*"))
    return [os.path.basename(t) for t in tars] if tars else ["(none)"]
def upload_and_extract(tar_file):
    if tar_file is None:
        return "no file uploaded.", gr.update(choices=list_saved_models())
    # gr.File hands back a path string in recent Gradio, a tempfile object in older versions
    src = tar_file if isinstance(tar_file, str) else tar_file.name
    filename = os.path.basename(src)
    dest = os.path.join(UPLOAD_DIR, filename)
    shutil.copy(src, dest)
    # peek inside for .ckpt files
    try:
        with tarfile.open(dest, "r:*") as tf:
            members = tf.getnames()
    except Exception as e:
        return f"failed to open tar: {e}", gr.update(choices=list_saved_models())
    ckpts = [m for m in members if m.endswith(".ckpt")]
    ckpt_msg = f"found {len(ckpts)} .ckpt file(s): {ckpts}" if ckpts else "no .ckpt files found (still saved)"
    return f"saved as `{filename}`. {ckpt_msg}", gr.update(choices=list_saved_models())
def load_model(model_tar_name):
    global loaded_model, loaded_tokenizer, loaded_model_name
    if not model_tar_name or model_tar_name == "(none)":
        return "select a model first."
    tar_path = os.path.join(UPLOAD_DIR, model_tar_name)
    if not os.path.exists(tar_path):
        return f"tar not found: {tar_path}"
    # clean and re-extract
    if os.path.exists(ACTIVE_MODEL_DIR):
        shutil.rmtree(ACTIVE_MODEL_DIR)
    os.makedirs(ACTIVE_MODEL_DIR)
    try:
        # note: extractall trusts archive paths; on Python 3.12+ (and patched
        # earlier releases) passing filter="data" blocks path-traversal entries
        with tarfile.open(tar_path, "r:*") as tf:
            tf.extractall(ACTIVE_MODEL_DIR)
    except Exception as e:
        return f"extraction failed: {e}"
    # find .ckpt files
    ckpts = glob.glob(os.path.join(ACTIVE_MODEL_DIR, "**", "*.ckpt"), recursive=True)
    # try loading as HF model first (config.json present), else fall back to ckpt
    hf_configs = glob.glob(os.path.join(ACTIVE_MODEL_DIR, "**", "config.json"), recursive=True)
    try:
        if hf_configs:
            model_dir = os.path.dirname(hf_configs[0])
            try:
                loaded_tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
            except Exception:
                # dir may ship only weights + config; fall back to the stock gpt2 tokenizer
                loaded_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            loaded_model = GPT2LMHeadModel.from_pretrained(model_dir)
            loaded_model.eval()
            loaded_model_name = model_tar_name
            return f"loaded HF-format model from `{model_dir}`. ckpts present: {[os.path.basename(c) for c in ckpts]}"
        elif ckpts:
            # bare .ckpt — assume state_dict for gpt2 base config
            ckpt_path = ckpts[0]
            # weights_only=False: torch >= 2.6 defaults to True, which rejects
            # pickled .ckpt wrappers; only do this for checkpoints you trust
            state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            # handle common wrapper keys
            if isinstance(state, dict):
                if "state_dict" in state:
                    state = state["state_dict"]
                elif "model" in state:
                    state = state["model"]
            config = GPT2Config()
            loaded_model = GPT2LMHeadModel(config)
            loaded_model.load_state_dict(state, strict=False)
            loaded_model.eval()
            loaded_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            loaded_model_name = model_tar_name
            return f"loaded ckpt `{os.path.basename(ckpt_path)}` (strict=False, base gpt2 config)"
        else:
            return "no config.json or .ckpt found in tar — can't load."
    except Exception as e:
        return f"model load failed: {e}"
def generate_text(prompt, max_new_tokens, temperature, top_p, top_k):
    global loaded_model, loaded_tokenizer, loaded_model_name
    if loaded_model is None:
        return "no model loaded. upload and load one first."
    if not prompt.strip():
        return "gimme a prompt."
    inputs = loaded_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = loaded_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # silences the missing-attention-mask warning
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            do_sample=True,
            pad_token_id=loaded_tokenizer.eos_token_id,
        )
    generated = loaded_tokenizer.decode(output[0], skip_special_tokens=True)
    return generated
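# a minimal sketch of a streaming variant of generate_text, using transformers'
# TextIteratorStreamer; generate_text_stream is a hypothetical name (not wired
# into the UI below), and Gradio renders generator functions as live-updating output.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_text_stream(prompt, max_new_tokens):
    if loaded_model is None:
        yield "no model loaded. upload and load one first."
        return
    streamer = TextIteratorStreamer(loaded_tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = loaded_tokenizer(prompt, return_tensors="pt")
    gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=int(max_new_tokens),
                      do_sample=True, pad_token_id=loaded_tokenizer.eos_token_id)
    Thread(target=loaded_model.generate, kwargs=gen_kwargs).start()
    text = prompt
    for chunk in streamer:  # yields decoded text as tokens arrive
        text += chunk
        yield text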
# --- UI ---
with gr.Blocks(title="gpt-2 model manager") as demo:
    gr.Markdown("## gpt-2 model uploader & generator\nupload a `.tar` or `.tar.gz` containing your model. supports HF-format dirs or bare `.ckpt` files.")
    with gr.Tab("upload & manage"):
        with gr.Row():
            tar_input = gr.File(label="upload .tar / .tar.gz", file_types=[".tar", ".gz", ".tar.gz"])
            upload_btn = gr.Button("upload & save")
        upload_status = gr.Textbox(label="status", interactive=False)
        gr.Markdown("---")
        model_dropdown = gr.Dropdown(label="saved models", choices=list_saved_models(), interactive=True)
        load_btn = gr.Button("load selected model")
        load_status = gr.Textbox(label="load status", interactive=False)
        upload_btn.click(
            upload_and_extract,
            inputs=[tar_input],
            outputs=[upload_status, model_dropdown],
        )
        load_btn.click(
            load_model,
            inputs=[model_dropdown],
            outputs=[load_status],
        )
    with gr.Tab("generate"):
        prompt_input = gr.Textbox(label="prompt", lines=3, placeholder="enter your prompt here...")
        with gr.Row():
            max_tokens = gr.Slider(10, 500, value=100, step=10, label="max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="temperature")
        with gr.Row():
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top-p")
            top_k = gr.Slider(0, 100, value=50, step=1, label="top-k")
        gen_btn = gr.Button("generate")
        output_text = gr.Textbox(label="output", lines=10, interactive=False)
        gen_btn.click(
            generate_text,
            inputs=[prompt_input, max_tokens, temperature, top_p, top_k],
            outputs=[output_text],
        )

demo.launch()
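# quick smoke test without the UI (filename is illustrative, not a real artifact):
#   load_model("my_model.tar.gz")
#   print(generate_text("once upon a time", 50, 1.0, 0.95, 50))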