import gradio as gr
import tarfile
import os
import shutil
import glob

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

UPLOAD_DIR = "uploaded_models"
ACTIVE_MODEL_DIR = "active_model"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Global model state
loaded_model = None
loaded_tokenizer = None
loaded_model_name = None


def list_saved_models():
    tars = glob.glob(os.path.join(UPLOAD_DIR, "*.tar*"))
    return [os.path.basename(t) for t in tars] if tars else ["(none)"]


def upload_and_extract(tar_file):
    if tar_file is None:
        return "no file uploaded.", gr.update(choices=list_saved_models())
    # gr.File may yield a tempfile-like object (gradio 3.x) or a plain
    # filepath string (gradio 4.x); handle both
    src = tar_file if isinstance(tar_file, str) else tar_file.name
    filename = os.path.basename(src)
    dest = os.path.join(UPLOAD_DIR, filename)
    shutil.copy(src, dest)
    # peek inside for .ckpt files
    try:
        with tarfile.open(dest, "r:*") as tf:
            members = tf.getnames()
    except Exception as e:
        return f"failed to open tar: {e}", gr.update(choices=list_saved_models())
    ckpts = [m for m in members if m.endswith(".ckpt")]
    ckpt_msg = (
        f"found {len(ckpts)} .ckpt file(s): {ckpts}"
        if ckpts
        else "no .ckpt files found (still saved)"
    )
    return f"saved as `{filename}`. {ckpt_msg}", gr.update(choices=list_saved_models())


def load_model(model_tar_name):
    global loaded_model, loaded_tokenizer, loaded_model_name
    if not model_tar_name or model_tar_name == "(none)":
        return "select a model first."
    tar_path = os.path.join(UPLOAD_DIR, model_tar_name)
    if not os.path.exists(tar_path):
        return f"tar not found: {tar_path}"
    # clean and re-extract
    if os.path.exists(ACTIVE_MODEL_DIR):
        shutil.rmtree(ACTIVE_MODEL_DIR)
    os.makedirs(ACTIVE_MODEL_DIR)
    try:
        with tarfile.open(tar_path, "r:*") as tf:
            # note: extractall trusts archive paths; on Python 3.12+ you can
            # pass filter="data" to reject path-traversal entries
            tf.extractall(ACTIVE_MODEL_DIR)
    except Exception as e:
        return f"extraction failed: {e}"
    # find .ckpt files
    ckpts = glob.glob(os.path.join(ACTIVE_MODEL_DIR, "**", "*.ckpt"), recursive=True)
    # try loading as an HF model first (config.json present), else fall back to ckpt
    hf_configs = glob.glob(os.path.join(ACTIVE_MODEL_DIR, "**", "config.json"), recursive=True)
    try:
        if hf_configs:
            model_dir = os.path.dirname(hf_configs[0])
            loaded_tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
            loaded_model = GPT2LMHeadModel.from_pretrained(model_dir)
            loaded_model.eval()
            loaded_model_name = model_tar_name
            return (
                f"loaded HF-format model from `{model_dir}`. "
                f"ckpts present: {[os.path.basename(c) for c in ckpts]}"
            )
        elif ckpts:
            # bare .ckpt: assume a state_dict for the base gpt2 config
            ckpt_path = ckpts[0]
            state = torch.load(ckpt_path, map_location="cpu")
            # handle common wrapper keys (e.g. lightning-style checkpoints)
            if isinstance(state, dict):
                if "state_dict" in state:
                    state = state["state_dict"]
                elif "model" in state:
                    state = state["model"]
            config = GPT2Config()
            loaded_model = GPT2LMHeadModel(config)
            loaded_model.load_state_dict(state, strict=False)
            loaded_model.eval()
            loaded_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            loaded_model_name = model_tar_name
            return f"loaded ckpt `{os.path.basename(ckpt_path)}` (strict=False, base gpt2 config)"
        else:
            return "no config.json or .ckpt found in tar; can't load."
    except Exception as e:
        return f"model load failed: {e}"


def generate_text(prompt, max_new_tokens, temperature, top_p, top_k):
    global loaded_model, loaded_tokenizer, loaded_model_name
    if loaded_model is None:
        return "no model loaded. upload and load one first."
    if not prompt.strip():
        return "gimme a prompt."
    inputs = loaded_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = loaded_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # silences the pad-vs-eos warning
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            do_sample=True,
            pad_token_id=loaded_tokenizer.eos_token_id,  # gpt-2 has no pad token
        )
    generated = loaded_tokenizer.decode(output[0], skip_special_tokens=True)
    return generated


# --- UI ---
with gr.Blocks(title="gpt-2 model manager") as demo:
    gr.Markdown(
        "## gpt-2 model uploader & generator\n"
        "upload a `.tar` or `.tar.gz` containing your model. "
        "supports HF-format dirs or bare `.ckpt` files."
    )

    with gr.Tab("upload & manage"):
        with gr.Row():
            tar_input = gr.File(label="upload .tar / .tar.gz", file_types=[".tar", ".gz", ".tar.gz"])
            upload_btn = gr.Button("upload & save")
        upload_status = gr.Textbox(label="status", interactive=False)
        gr.Markdown("---")
        model_dropdown = gr.Dropdown(label="saved models", choices=list_saved_models(), interactive=True)
        load_btn = gr.Button("load selected model")
        load_status = gr.Textbox(label="load status", interactive=False)

        upload_btn.click(
            upload_and_extract,
            inputs=[tar_input],
            outputs=[upload_status, model_dropdown],
        )
        load_btn.click(
            load_model,
            inputs=[model_dropdown],
            outputs=[load_status],
        )

    with gr.Tab("generate"):
        prompt_input = gr.Textbox(label="prompt", lines=3, placeholder="enter your prompt here...")
        with gr.Row():
            max_tokens = gr.Slider(10, 500, value=100, step=10, label="max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="temperature")
        with gr.Row():
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top-p")
            top_k = gr.Slider(0, 100, value=50, step=1, label="top-k")
        gen_btn = gr.Button("generate")
        output_text = gr.Textbox(label="output", lines=10, interactive=False)

        gen_btn.click(
            generate_text,
            inputs=[prompt_input, max_tokens, temperature, top_p, top_k],
            outputs=[output_text],
        )

demo.launch()
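
# --- example: packaging a model for upload ---
# a minimal sketch of how to build a tarball this app accepts; the paths
# "my_gpt2" and "my_model.tar.gz" are illustrative, not part of the app.
# run it separately (demo.launch() above blocks), e.g. from a REPL:
#
#   from transformers import GPT2LMHeadModel, GPT2Tokenizer
#   import tarfile
#
#   model = GPT2LMHeadModel.from_pretrained("gpt2")
#   tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#   model.save_pretrained("my_gpt2")      # writes config.json + weights
#   tokenizer.save_pretrained("my_gpt2")  # writes vocab.json / merges.txt
#
#   with tarfile.open("my_model.tar.gz", "w:gz") as tf:
#       tf.add("my_gpt2", arcname="my_gpt2")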