# mcid-generator / app.py
# Last commit: 1ffcc33 "Add application file" (author: Yajii2, unverified)
import functools

import gradio as gr
import torch

from model.gpt_char_model import CharGPT
from tokenizer import CharTokenizer
def load_model(model_path="model/gpt_char_model.pth", block_size=32, device=None):
    """Build a CharGPT, load trained weights into it and return it ready for inference.

    Args:
        model_path: Path to a ``state_dict`` checkpoint (it is fed straight to
            ``model.load_state_dict``).
        block_size: Context length passed to ``CharGPT``; presumably must match
            the value the checkpoint was trained with — TODO confirm.
        device: Device to place the model on. ``None`` (the default) means CPU,
            which is what this app has always used; pass ``"cuda"`` or a
            ``torch.device`` to override.

    Returns:
        ``(model, tokenizer, device)``: the model in eval mode, the
        ``CharTokenizer`` whose vocabulary sized the model, and the
        ``torch.device`` actually used.
    """
    # The old code computed a CUDA-or-CPU device and then unconditionally
    # overwrote it with CPU; keep CPU as the default but make it overridable.
    device = torch.device("cpu" if device is None else device)

    tokenizer = CharTokenizer()
    # These hyperparameters must match the checkpoint's architecture:
    # load_state_dict is strict by default and raises on any mismatch.
    model = CharGPT(
        vocab_size=len(tokenizer.chars),
        block_size=block_size,
        n_layer=6,
        n_head=4,
        n_embd=256,
    ).to(device)

    # The checkpoint is a plain state_dict of tensors, so restrict unpickling
    # to tensors/containers instead of allowing arbitrary pickled code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference only: disable dropout etc.
    return model, tokenizer, device
@functools.lru_cache(maxsize=4)
def _load_cached_model(model_path):
    """Return ``load_model(model_path=...)``, loading each checkpoint only once.

    Rebuilding the network and reading weights from disk on every request is
    wasteful; the loaded ``(model, tokenizer, device)`` is memoised per path.
    """
    return load_model(model_path=model_path)


@torch.no_grad()
def generate_username(
    seed_text="",
    min_length=1,
    max_length=16,
    temperature=1.0,
    model_path="model/gpt_char_model_v3.pth",
):
    """Sample a username from the character-level GPT.

    Args:
        seed_text: Characters the username starts with.
        min_length: Minimum number of characters (seed included) before the
            model is allowed to end the name by sampling a newline.
        max_length: Maximum number of characters sampled after the seed.
        temperature: Softmax temperature; must be > 0. Higher values give more
            random names.
        model_path: Checkpoint to sample from (loaded once, then cached).

    Returns:
        The seed followed by the sampled characters, with surrounding
        whitespace stripped.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    model, tokenizer, device = _load_cached_model(model_path)

    # Token id 0 is prepended as a start-of-name marker. It is excluded from
    # length counts; presumably it decodes to whitespace, since it is expected
    # to disappear in the final .strip() — TODO confirm against CharTokenizer.
    prefix_len = 1
    ids = torch.tensor(
        [[0] + tokenizer.encode(seed_text)], dtype=torch.long, device=device
    )

    for _ in range(max_length):
        context = ids[:, -model.block_size :]  # keep within the context window
        logits = model(context)[:, -1, :] / temperature
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)

        if tokenizer.decode(next_id[0].tolist()) == "\n":
            if ids.shape[1] - prefix_len >= min_length:
                break  # newline ends the name once it is long enough
            # Too short to end: forbid the newline and redraw, rather than
            # discarding the whole step (which could leave the name shorter
            # than min_length when the loop runs out of steps).
            probs[0, next_id.item()] = 0.0
            if probs.sum() <= 0:
                break  # the model put all its mass on newline; nothing to draw
            next_id = torch.multinomial(probs, num_samples=1)

        ids = torch.cat((ids, next_id), dim=1)

    return tokenizer.decode(ids[0].tolist()).strip()
def gradio_interface(seed_text, min_length, max_length, temperature):
    """Bridge raw Gradio widget values to ``generate_username``.

    Values coming from the UI widgets may arrive as floats, so both lengths are
    coerced to ``int`` and the temperature to ``float`` before delegating.
    """
    shortest = int(min_length)
    longest = int(max_length)
    temp = float(temperature)
    return generate_username(seed_text, shortest, longest, temp)
# UI: a seed textbox, sliders for the sampling controls, and an output box
# filled by gradio_interface when the button is clicked.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# MCID Generator")
    with gr.Row():
        seed = gr.Textbox(label="Start token", value="")
    with gr.Row():
        with gr.Column():
            min_length = gr.Slider(1, 32, value=1, step=1, label="Minimum length")
            max_length = gr.Slider(1, 32, value=16, step=1, label="Maximum length")
            temperature = gr.Slider(
                0.5, 2.0, value=1.0, step=0.05, label="Temperature"
            )
    with gr.Row():
        output = gr.Textbox(label="Generated username")
    generate_btn = gr.Button("Generate")
    # Event listeners must be registered inside the Blocks context.
    generate_btn.click(
        gradio_interface,
        inputs=[seed, min_length, max_length, temperature],
        outputs=output,
    )

# Only start the server when run as a script (`python app.py`, which is how
# the Space launches it), so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch(share=True)