Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from model.gpt_char_model import CharGPT | |
| from tokenizer import CharTokenizer | |
def load_model(model_path="model/gpt_char_model.pth", block_size=32):
    """Build a ``CharGPT`` model, load its weights, and return it ready for inference.

    Args:
        model_path: Path of the ``state_dict`` checkpoint to load.
        block_size: Context length passed to the ``CharGPT`` constructor.

    Returns:
        A ``(model, tokenizer, device)`` tuple: the model in eval mode on
        ``device``, the ``CharTokenizer`` whose ``chars`` sized the vocabulary,
        and the ``torch.device`` everything lives on (always CPU).
    """
    # Inference is pinned to the CPU.  The previous code probed for CUDA and
    # then unconditionally overwrote the result, so the probe was a dead store.
    device = torch.device("cpu")

    tokenizer = CharTokenizer()
    vocab_size = len(tokenizer.chars)

    # These hyperparameters have to match the checkpoint being loaded;
    # load_state_dict raises if the parameter names/shapes disagree.
    model = CharGPT(
        vocab_size=vocab_size,
        block_size=block_size,
        n_layer=6,
        n_head=4,
        n_embd=256,
    ).to(device)

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here.  Consider passing weights_only=True once the
    # minimum supported torch version provides it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, tokenizer, device
# Loaded (model, tokenizer, device) triples keyed by checkpoint path, so the
# weights are read from disk once instead of on every generation request.
_MODEL_CACHE = {}


def _get_cached_model(model_path):
    """Return ``load_model(model_path=model_path)``, loading each path only once."""
    if model_path not in _MODEL_CACHE:
        _MODEL_CACHE[model_path] = load_model(model_path=model_path)
    return _MODEL_CACHE[model_path]


def generate_username(seed_text="", min_length=1, max_length=16, temperature=1.0):
    """Sample a username character by character from the CharGPT model.

    Args:
        seed_text: Characters the username must start with (may be empty).
        min_length: A sampled newline is rejected while the username (seed
            plus generated characters) is shorter than this.  A rejected
            newline still uses up one sampling step.
        max_length: Maximum number of sampling steps, i.e. an upper bound on
            the number of characters appended after the seed.
        temperature: Softmax temperature; logits are divided by it, so values
            above 1 flatten the distribution and values below 1 sharpen it.

    Returns:
        The decoded sequence (start token, seed and generated characters)
        with surrounding whitespace stripped.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model, tokenizer, device = _get_cached_model("model/gpt_char_model_v3.pth")

    # Token id 0 is prepended before the seed; presumably it is the
    # start/separator token the model was trained with -- TODO confirm
    # against CharTokenizer.
    token_ids = [0, *tokenizer.encode(seed_text)]
    input_ids = torch.tensor([token_ids], dtype=torch.long).to(device)

    # Pure inference: no autograd bookkeeping is needed while sampling.
    with torch.no_grad():
        for _ in range(max_length):
            # The model is only fed its last block_size tokens.
            input_crop = input_ids[:, -model.block_size:]
            logits = model(input_crop)
            logits = logits[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            next_char = tokenizer.decode(next_id[0].tolist())

            if next_char == "\n":
                # Exclude the prepended start token from the length; counting
                # it let a name one character too short (even an empty one
                # for min_length=1) end the generation.
                current_length = input_ids.shape[1] - 1
                if current_length < min_length:
                    continue
                break

            input_ids = torch.cat((input_ids, next_id), dim=1)

    return tokenizer.decode(input_ids[0].tolist()).strip()
def gradio_interface(seed_text, min_length, max_length, temperature):
    """Adapt raw Gradio widget values to the types ``generate_username`` expects.

    The two length values are coerced to ``int`` and the temperature to
    ``float``; the seed text is passed through unchanged.  Returns whatever
    ``generate_username`` returns.
    """
    shortest = int(min_length)
    longest = int(max_length)
    sampling_temperature = float(temperature)
    return generate_username(seed_text, shortest, longest, sampling_temperature)
# Gradio UI: a seed textbox, three sampling controls, an output textbox and a
# button that runs gradio_interface on the current control values.
# NOTE(review): the Row/Column nesting below was reconstructed from a source
# whose indentation had been stripped -- confirm it against the deployed layout.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(value="# MCID Generator")

    # Seed input.
    with gr.Row():
        seed = gr.Textbox(label="Start token", value="")

    # Sampling controls, stacked in a single column.
    with gr.Row():
        with gr.Column():
            min_length = gr.Slider(
                minimum=1,
                maximum=32,
                value=1,
                step=1,
                label="Minimum length",
            )
            max_length = gr.Slider(
                minimum=1,
                maximum=32,
                value=16,
                step=1,
                label="Maximum length",
            )
            temperature = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.05,
                label="Temperature",
            )

    # Result display.
    with gr.Row():
        output = gr.Textbox(label="Generated username")

    generate_btn = gr.Button(value="Generate")

    # Clicking the button feeds the four controls to the generator and shows
    # its return value in the output textbox.
    generate_btn.click(
        fn=gradio_interface,
        inputs=[seed, min_length, max_length, temperature],
        outputs=output,
    )

# Start the app as soon as the module is executed, requesting a public share link.
demo.launch(share=True)