| import json |
| from functools import partial |
|
|
| import torch |
| import gradio as gr |
|
|
| from src.model import CharRNN |
| from src.config import MODEL_HPARAMS |
| from src.preprocessing import text_encoder, text_decoder |
| from src.inference import generate_text |
| from src.utils import load_css, load_markdown |
|
|
| |
| |
| |
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Torch Device: {device}")
|
|
| |
# Load the character vocabulary. JSON text is UTF-8 by specification, so decode
# it explicitly instead of relying on the platform's locale encoding (which is
# not UTF-8 on e.g. Windows and would garble or reject non-ASCII characters).
with open("artifacts/vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)

# Ids are assigned by the key order of the JSON object; the values stored in
# vocab.json are not used here.
# NOTE(review): assumes this key order matches the ids used when the model was
# trained — confirm against the training/preprocessing code.
char2id = {char: idx for idx, char in enumerate(vocab)}
id2char = dict(enumerate(vocab))

# Pre-bind the lookup tables as keyword arguments of the project's
# encoder/decoder so callers only supply the remaining arguments.
encode_fn = partial(text_encoder, char2id=char2id)
decode_fn = partial(text_decoder, id2char=id2char)
|
|
| |
# Build the network with the vocabulary size and the shared hyperparameters,
# then load the trained weights.
# weights_only=True restricts unpickling to tensors and plain containers;
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
model = CharRNN(len(vocab), **MODEL_HPARAMS).to(device)
state_dict = torch.load("model/char_rnn_model_params.pt", weights_only=True, map_location=device)
model.load_state_dict(state_dict)
# This app only runs inference: switch off training-time behavior such as
# dropout so generation uses the deterministic forward pass.
model.eval()
|
|
| |
| |
| |
def generate_fn(text, max_len, temperature, include_prompt):
    """Gradio callback: generate text from a prompt with the loaded model.

    Args:
        text: Prompt string from the "Prompt" textbox.
        max_len: Value of the "Max Length" slider. Gradio's Slider passes its
            value as a float, so it is cast to ``int`` before being used as a
            length.
        temperature: Value of the "Temperature" slider, passed through to
            ``generate_text``.
        include_prompt: Checkbox value; whether the prompt is included in the
            output.

    Returns:
        Whatever ``generate_text`` returns; it is displayed in the output box.
    """
    # Generation needs no gradients; skipping autograd bookkeeping saves memory.
    with torch.no_grad():
        return generate_text(model, text, encode_fn, decode_fn,
                             device, int(max_len), float(temperature), include_prompt)
|
|
def set_english():
    """Show the English title and summary.

    Returns:
        A 2-tuple of ``gr.update`` objects for (title, summary). Each one sets
        the English text and resets the element classes to an empty list,
        dropping the ``persian`` class that ``set_persian`` applies.
    """
    return tuple(
        gr.update(value=text, elem_classes=[])
        for text in (english_title, english_summary)
    )
|
|
def set_persian():
    """Show the Persian title and summary.

    Returns:
        A 2-tuple of ``gr.update`` objects for (title, summary). Each one sets
        the Persian text and tags the component with the ``persian`` CSS class
        (presumably RTL/font styling from the loaded stylesheet — TODO confirm).
    """
    return tuple(
        gr.update(value=text, elem_classes=["persian"])
        for text in (persian_title, persian_summary)
    )
|
|
| |
| |
| |
# Static UI assets: the custom stylesheet plus bilingual title/summary text.
css = load_css()

english_title = "# 🎭 Char-RNN: Generate Shakespearean Texts with GRU Network"
# Persian title; reads "Creating text in Shakespeare's style with CharRNN".
persian_title = "# 🎭 آفرینش متن به سبک شکسپیر با CharRNN"

# Summaries are loaded in the same order as before: English first, then Persian.
english_summary, persian_summary = (
    load_markdown("english_summary"),
    load_markdown("persian_summary"),
)
|
|
| |
| |
| |
# Build the Gradio UI. Components are created inside nested `with` contexts and
# Gradio Blocks places them in creation order, so keep the statement order as-is.
with gr.Blocks(css=css, title="CharRNN Shakespeare Generator") as demo:
    # Page title; its text and CSS classes are replaced by the language buttons.
    title_md = gr.Markdown(english_title, elem_id="title")

    # Language toggle buttons (wired to set_english / set_persian below).
    with gr.Row():
        english_btn = gr.Button("English")
        persian_btn = gr.Button("فارسی (Persian)")

    # Project summary; also replaced by the language buttons.
    summary_md = gr.Markdown(english_summary, elem_id="summary")

    # Main panel: generation controls in the left column, output in the right.
    with gr.Row(variant="panel"):
        with gr.Column(scale=1, variant="panel"):
            prompt_box = gr.Textbox(label="Prompt", lines=3)
            max_len_slider = gr.Slider(10, 500, value=128, step=10, label="Max Length")
            temp_slider = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            include_prompt = gr.Checkbox(value=True, label="Include Prompt in Output")
            generate_btn = gr.Button("✨ Generate Text", variant="primary")

        with gr.Column(scale=1, variant="panel"):
            # Read-only box that receives generate_fn's result.
            # NOTE(review): named `output_md` with elem_id "help_text" although it
            # is a Textbox; the id may be referenced by the stylesheet from
            # load_css() — confirm before renaming.
            output_md = gr.Textbox(elem_id="help_text", lines=7, max_lines=None, interactive=False)

    # Event wiring. The `inputs` list is passed positionally to generate_fn as
    # (text, max_len, temperature, include_prompt), so its order must match.
    generate_btn.click(generate_fn, inputs=[prompt_box, max_len_slider, temp_slider, include_prompt], outputs=output_md)
    # Each language handler returns a (title, summary) pair of updates.
    english_btn.click(set_english, outputs=[title_md, summary_md])
    persian_btn.click(set_persian, outputs=[title_md, summary_md])

# Start the Gradio server. This runs at module level (no __main__ guard), so
# importing this file also launches the app.
demo.launch()