# gpt2 / app.py
import gradio as gr
import tarfile
import os
import shutil
import glob
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
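# UPLOAD_DIR keeps the raw uploaded tars; ACTIVE_MODEL_DIR holds the extracted
# contents of whichever model is currently loaded (wiped on every load)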
UPLOAD_DIR = "uploaded_models"
ACTIVE_MODEL_DIR = "active_model"
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Global model state
loaded_model = None
loaded_tokenizer = None
loaded_model_name = None
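# note: these globals are shared by every Gradio session, so the Space serves
# exactly one active model at a time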
def list_saved_models():
    tars = glob.glob(os.path.join(UPLOAD_DIR, "*.tar*"))
    return [os.path.basename(t) for t in tars] if tars else ["(none)"]
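
# save an uploaded tar into UPLOAD_DIR and report any .ckpt members found;
# the archive is only inspected here, extraction happens in load_model()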
def upload_and_extract(tar_file):
    if tar_file is None:
        return "no file uploaded.", gr.update(choices=list_saved_models())
    # gr.File may return a tempfile-like object (Gradio 3.x) or a plain path
    # string (Gradio 4.x default), so handle both
    src_path = tar_file.name if hasattr(tar_file, "name") else tar_file
    filename = os.path.basename(src_path)
    dest = os.path.join(UPLOAD_DIR, filename)
    shutil.copy(src_path, dest)
    # peek inside for .ckpt files without extracting anything yet
    try:
        with tarfile.open(dest, "r:*") as tf:
            members = tf.getnames()
    except Exception as e:
        return f"failed to open tar: {e}", gr.update(choices=list_saved_models())
    ckpts = [m for m in members if m.endswith(".ckpt")]
    ckpt_msg = f"found {len(ckpts)} .ckpt file(s): {ckpts}" if ckpts else "no .ckpt files found (still saved)"
    return f"saved as `{filename}`. {ckpt_msg}", gr.update(choices=list_saved_models())
def load_model(model_tar_name):
    global loaded_model, loaded_tokenizer, loaded_model_name
    if not model_tar_name or model_tar_name == "(none)":
        return "select a model first."
    tar_path = os.path.join(UPLOAD_DIR, model_tar_name)
    if not os.path.exists(tar_path):
        return f"tar not found: {tar_path}"
    # clean and re-extract
    if os.path.exists(ACTIVE_MODEL_DIR):
        shutil.rmtree(ACTIVE_MODEL_DIR)
    os.makedirs(ACTIVE_MODEL_DIR)
    try:
        with tarfile.open(tar_path, "r:*") as tf:
            # extractall trusts the archive; on Python 3.12+ you can pass
            # filter="data" to reject path-traversal members
            tf.extractall(ACTIVE_MODEL_DIR)
    except Exception as e:
        return f"extraction failed: {e}"
    # find .ckpt files
    ckpts = glob.glob(os.path.join(ACTIVE_MODEL_DIR, "**", "*.ckpt"), recursive=True)
    # try loading as an HF model first (config.json present), else fall back to ckpt
    hf_configs = glob.glob(os.path.join(ACTIVE_MODEL_DIR, "**", "config.json"), recursive=True)
    try:
        if hf_configs:
            model_dir = os.path.dirname(hf_configs[0])
            loaded_model = GPT2LMHeadModel.from_pretrained(model_dir)
            try:
                loaded_tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
            except Exception:
                # tar shipped weights but no tokenizer files; fall back to the base vocab
                loaded_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            loaded_model.eval()
            loaded_model_name = model_tar_name
            return f"loaded HF-format model from `{model_dir}`. ckpts present: {[os.path.basename(c) for c in ckpts]}"
        elif ckpts:
            # bare .ckpt: assume it holds a state_dict for the base gpt2 config
            ckpt_path = ckpts[0]
            # torch >= 2.6 defaults to weights_only=True, which rejects pickled
            # objects beyond plain tensors/containers; pass weights_only=False
            # only for checkpoints you trust
            state = torch.load(ckpt_path, map_location="cpu")
            # unwrap common checkpoint wrapper keys (e.g. Lightning's "state_dict")
            if isinstance(state, dict):
                if "state_dict" in state:
                    state = state["state_dict"]
                elif "model" in state:
                    state = state["model"]
            config = GPT2Config()
            loaded_model = GPT2LMHeadModel(config)
            loaded_model.load_state_dict(state, strict=False)
            loaded_model.eval()
            loaded_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            loaded_model_name = model_tar_name
            return f"loaded ckpt `{os.path.basename(ckpt_path)}` (strict=False, base gpt2 config)"
        else:
            return "no config.json or .ckpt found in tar — can't load."
    except Exception as e:
        return f"model load failed: {e}"
def generate_text(prompt, max_new_tokens, temperature, top_p, top_k):
    global loaded_model, loaded_tokenizer, loaded_model_name
    if loaded_model is None:
        return "no model loaded. upload and load one first."
    if not prompt.strip():
        return "gimme a prompt."
    inputs = loaded_tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    with torch.no_grad():
        output = loaded_model.generate(
            input_ids,
            attention_mask=inputs["attention_mask"],  # avoids the pad-vs-eos ambiguity warning
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            do_sample=True,
            pad_token_id=loaded_tokenizer.eos_token_id,
        )
    generated = loaded_tokenizer.decode(output[0], skip_special_tokens=True)
    return generated
# --- UI ---
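# two tabs: one to upload/select/load a model, one to prompt the loaded model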
with gr.Blocks(title="gpt-2 model manager") as demo:
    gr.Markdown("## gpt-2 model uploader & generator\nupload a `.tar` or `.tar.gz` containing your model. supports HF-format dirs or bare `.ckpt` files.")

    with gr.Tab("upload & manage"):
        with gr.Row():
            tar_input = gr.File(label="upload .tar / .tar.gz", file_types=[".tar", ".gz", ".tar.gz"])
            upload_btn = gr.Button("upload & save")
        upload_status = gr.Textbox(label="status", interactive=False)
        gr.Markdown("---")
        model_dropdown = gr.Dropdown(label="saved models", choices=list_saved_models(), interactive=True)
        load_btn = gr.Button("load selected model")
        load_status = gr.Textbox(label="load status", interactive=False)

        upload_btn.click(
            upload_and_extract,
            inputs=[tar_input],
            outputs=[upload_status, model_dropdown],
        )
        load_btn.click(
            load_model,
            inputs=[model_dropdown],
            outputs=[load_status],
        )

    with gr.Tab("generate"):
        prompt_input = gr.Textbox(label="prompt", lines=3, placeholder="enter your prompt here...")
        with gr.Row():
            max_tokens = gr.Slider(10, 500, value=100, step=10, label="max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="temperature")
        with gr.Row():
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top-p")
            top_k = gr.Slider(0, 100, value=50, step=1, label="top-k")
        gen_btn = gr.Button("generate")
        output_text = gr.Textbox(label="output", lines=10, interactive=False)

        gen_btn.click(
            generate_text,
            inputs=[prompt_input, max_tokens, temperature, top_p, top_k],
            outputs=[output_text],
        )
demo.launch()
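
# quick smoke-test sketch (assumes a previously uploaded tar named
# "my_model.tar.gz"; run these instead of demo.launch(), which blocks):
#   print(load_model("my_model.tar.gz"))
#   print(generate_text("hello world", 50, 1.0, 0.95, 50))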