# CharRNN / app.py
# Author: hoom4n — commit b6447fa ("Upload 18 files", verified)
import json
from functools import partial
import torch
import gradio as gr
from src.model import CharRNN
from src.config import MODEL_HPARAMS
from src.preprocessing import text_encoder, text_decoder
from src.inference import generate_text
from src.utils import load_css, load_markdown
# -------------------
# Setup
# -------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Torch Device: {device}")

# Load vocab. Read explicitly as UTF-8: the vocabulary holds raw characters and
# the platform default encoding is not guaranteed to be UTF-8 (e.g. on Windows),
# which would mis-decode any non-ASCII entries.
with open("artifacts/vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)

# Character <-> id lookup tables; ids follow the key order of vocab.json.
char2id = {char: idx for idx, char in enumerate(vocab)}
id2char = {idx: char for char, idx in char2id.items()}
encode_fn = partial(text_encoder, char2id=char2id)
decode_fn = partial(text_decoder, id2char=id2char)

# Load model
model = CharRNN(len(vocab), **MODEL_HPARAMS).to(device)
state_dict = torch.load("model/char_rnn_model_params.pt", weights_only=True, map_location=device)
model.load_state_dict(state_dict)
# This app only runs inference: switch to eval mode so any train-time-only
# layers (e.g. dropout, if CharRNN has any) are disabled. Whether
# generate_text also does this is not visible here; calling eval() is idempotent.
model.eval()
# -------------------
# Inference function
# -------------------
def generate_fn(text, max_len, temperature, include_prompt):
    """Gradio callback: generate text from a prompt with the loaded CharRNN.

    Args:
        text: Prompt string from the "Prompt" textbox.
        max_len: Value of the "Max Length" slider.
        temperature: Value of the "Temperature" slider.
        include_prompt: Checkbox value, forwarded unchanged to ``generate_text``.

    Returns:
        Whatever ``generate_text`` returns; it is displayed in the output textbox.
    """
    # Gradio sliders can hand numbers over as floats (e.g. 128.0). Coerce them so
    # a length used for looping/indexing is a real int and temperature a float.
    return generate_text(model, text, encode_fn, decode_fn,
                         device, int(max_len), float(temperature), include_prompt)
def set_english():
    """Show the English title and summary.

    Both components get an empty ``elem_classes`` list, which drops the
    ``persian`` class that ``set_persian`` applies.

    Returns:
        A ``(title, summary)`` pair of ``gr.update`` objects.
    """
    title_update = gr.update(value=english_title, elem_classes=[])
    summary_update = gr.update(value=english_summary, elem_classes=[])
    return title_update, summary_update
def set_persian():
    """Show the Persian title and summary, tagging both with the ``persian`` CSS class.

    Returns:
        A ``(title, summary)`` pair of ``gr.update`` objects.
    """
    # A fresh class list is built for each update (no shared list object).
    return tuple(
        gr.update(value=content, elem_classes=["persian"])
        for content in (persian_title, persian_summary)
    )
# -------------------
# Gradio UI
# -------------------
# Stylesheet for the Blocks app.
css = load_css()

# Markdown page headings, one per language.
english_title, persian_title = (
    "# 🎭 Char-RNN: Generate Shakespearean Texts with GRU Network",
    "# 🎭 آفرینش متن به سبک شکسپیر با CharRNN",
)

# Summary bodies fetched by name through load_markdown (English first, then Persian).
english_summary, persian_summary = (
    load_markdown("english_summary"),
    load_markdown("persian_summary"),
)
# -------------------
# Gradio App
# -------------------
# Build the Blocks UI. Components appear on the page in the order they are
# created inside these context managers, so statement order defines the layout.
with gr.Blocks(css=css, title="CharRNN Shakespeare Generator") as demo:
    # Page heading; its text and CSS classes are swapped by the language buttons.
    title_md = gr.Markdown(english_title, elem_id="title")
    # Language toggle buttons.
    with gr.Row():
        english_btn = gr.Button("English")
        persian_btn = gr.Button("فارسی (Persian)")
    # Project summary; also swapped by the language buttons.
    summary_md = gr.Markdown(english_summary, elem_id="summary")
    # generation panel
    with gr.Row(variant="panel"):
        # Left column: prompt and sampling controls.
        with gr.Column(scale=1, variant="panel"):
            prompt_box = gr.Textbox(label="Prompt", lines=3)
            max_len_slider = gr.Slider(10, 500, value=128, step=10, label="Max Length")
            temp_slider = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            include_prompt = gr.Checkbox(value=True, label="Include Prompt in Output")
            generate_btn = gr.Button("✨ Generate Text", variant="primary")
        # Right column: read-only output (a Textbox despite the `_md` suffix).
        with gr.Column(scale=1, variant="panel"):
            output_md = gr.Textbox(elem_id="help_text", lines=7, max_lines=None, interactive=False)
    # events
    # `inputs` order must match generate_fn's parameters:
    # (text, max_len, temperature, include_prompt).
    generate_btn.click(generate_fn, inputs=[prompt_box, max_len_slider, temp_slider, include_prompt], outputs=output_md)
    # Each language callback returns a (title, summary) pair of updates.
    english_btn.click(set_english, outputs=[title_md, summary_md])
    persian_btn.click(set_persian, outputs=[title_md, summary_md])
# Launch the Gradio app.
demo.launch()