# Doha-Generator / app.py
# Author: nikpatidar333 · last commit 4e1b4f8 ("Update app.py")
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
import math
import unicodedata
import gradio as gr
from huggingface_hub import hf_hub_download
# ==========================================
# 1. CONFIG
# ==========================================
class Config:
    """Hyper-parameters shared by model construction and text generation."""

    # --- Transformer dimensions (must match the checkpoint) ---
    D_MODEL: int = 256
    N_HEADS: int = 8
    N_ENC_LAYERS: int = 4
    N_MEANING_DEC_LAYERS: int = 4
    N_DOHA_DEC_LAYERS: int = 4
    D_FF: int = 1024
    DROPOUT: float = 0.15

    # --- Length limits, in tokens ---
    MAX_SEQ_LEN: int = 256  # also the positional-encoding table size
    MAX_MEANING_LEN: int = 60
    MAX_DOHA_LEN: int = 48

    # --- Special token ids ---
    PAD_ID: int = 0
    BOS_ID: int = 2
    EOS_ID: int = 3

    # --- Sampling settings ---
    # GEN_REP_PENALTY is applied while sampling the meaning, GEN_DOHA_REP_PEN
    # while sampling the doha; the other three mirror generate()'s defaults.
    GEN_TEMPERATURE: float = 0.8
    GEN_TOP_K: int = 50
    GEN_TOP_P: float = 0.92
    GEN_REP_PENALTY: float = 1.3
    GEN_DOHA_REP_PEN: float = 1.5


cfg = Config()
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
HF_REPO_ID = "nikpatidar333/doha-generation-model_v2"
# ==========================================
# 2. MODEL ARCHITECTURE
# ==========================================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then apply dropout.

    Args:
        d_model: embedding size (odd sizes are supported).
        max_len: number of positions in the precomputed table; longer inputs
            are rejected.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one fewer odd column than frequencies; trim the
        # extra cosine column. For even d_model the slice keeps everything.
        pe[:, 1::2] = torch.cos(pos * div)[:, : d_model // 2]
        # The buffer name "pe" is part of the state_dict, so it must stay unchanged.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return dropout(x + pe[:, :seq_len]) for x of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if seq_len exceeds the table built in __init__.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return self.dropout(x + self.pe[:, :seq_len])
class DohaDecoderLayer(nn.Module):
    """One doha-decoder layer with gated cross-attention over two memories.

    Sub-layers (pre-norm residual):
      1. causal self-attention over the doha prefix;
      2. cross-attention to the encoder memory and, in parallel, to the meaning
         memory; the two results are blended by a learnable scalar
         sigmoid(gate) / 1 - sigmoid(gate);
      3. GELU feed-forward.
    A final LayerNorm (norm4) is applied to the layer output.

    Attribute names are state_dict keys of the checkpoint, and sub-modules are
    created in the same order as before so seeded initialisation is unchanged.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn_enc = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn_meaning = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        # Raw gate logit; sigmoid(gate) is the weight given to the encoder branch.
        self.gate = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, encoder_memory, meaning_memory, tgt_mask, enc_kpm, mean_kpm):
        """Run the layer.

        tgt_mask is used as the self-attention attn_mask. enc_kpm / mean_kpm are
        key-padding masks for the two memories (UnifiedDohaModel.decode_doha
        passes the inverted "is real token" masks, so True marks padding).
        """
        # 1) causal self-attention
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, attn_mask=tgt_mask)[0]

        # 2) gated mix of the two cross-attentions (both read the same normed query)
        h = self.norm2(x)
        from_source = self.cross_attn_enc(
            h, encoder_memory, encoder_memory, key_padding_mask=enc_kpm
        )[0]
        from_meaning = self.cross_attn_meaning(
            h, meaning_memory, meaning_memory, key_padding_mask=mean_kpm
        )[0]
        mix = torch.sigmoid(self.gate)
        x = x + mix * from_source + (1.0 - mix) * from_meaning

        # 3) feed-forward
        x = x + self.ffn(self.norm3(x))
        return self.norm4(x)
class UnifiedDohaModel(nn.Module):
    """Shared-embedding Transformer with one encoder and two decoders.

    * encoder: pre-norm nn.TransformerEncoder over the prompt tokens;
    * meaning decoder: pre-norm nn.TransformerDecoder attending to the encoder;
    * doha decoder: stack of DohaDecoderLayer attending to both the encoder
      memory and the meaning-decoder output, followed by doha_dec_norm.

    Both output heads (meaning_proj, doha_proj) share the embedding matrix.
    The decode_* methods return hidden states; callers apply the projection.
    Attribute names are the checkpoint's state_dict keys and must not change.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_enc, n_mdec, n_ddec, d_ff, dropout, max_len, pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        # max_len bounds every sequence that goes through _embed.
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_enc)

        meaning_layer = nn.TransformerDecoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        self.meaning_decoder = nn.TransformerDecoder(meaning_layer, num_layers=n_mdec)

        self.doha_dec_layers = nn.ModuleList(
            [DohaDecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_ddec)]
        )
        self.doha_dec_norm = nn.LayerNorm(d_model)

        self.meaning_proj = nn.Linear(d_model, vocab_size, bias=False)
        self.doha_proj = nn.Linear(d_model, vocab_size, bias=False)
        # Tie both output heads to the input embedding.
        self.meaning_proj.weight = self.embedding.weight
        self.doha_proj.weight = self.embedding.weight

    def _embed(self, ids):
        """Scaled token embedding plus positional encoding for a (batch, len) id tensor."""
        return self.pos_enc(self.embedding(ids) * math.sqrt(self.d_model))

    @staticmethod
    def _causal_mask(ids):
        """Square subsequent-position mask for the sequence length of *ids*, on its device."""
        return nn.Transformer.generate_square_subsequent_mask(ids.size(1)).to(ids.device)

    def encode(self, src, src_mask):
        """Encode *src*; *src_mask* is True at real tokens and is inverted for padding."""
        padding = ~src_mask
        return self.encoder(self._embed(src), src_key_padding_mask=padding)

    def decode_meaning(self, tgt, memory, src_mask):
        """Causally decode meaning tokens *tgt* against encoder *memory*."""
        causal = self._causal_mask(tgt)
        hidden = self._embed(tgt)
        return self.meaning_decoder(
            hidden, memory, tgt_mask=causal, memory_key_padding_mask=~src_mask
        )

    def decode_doha(self, tgt, enc_mem, mean_mem, enc_mask, mean_mask):
        """Causally decode doha tokens *tgt* against encoder and meaning memories."""
        causal = self._causal_mask(tgt)
        hidden = self._embed(tgt)
        enc_padding, mean_padding = ~enc_mask, ~mean_mask
        for layer in self.doha_dec_layers:
            hidden = layer(hidden, enc_mem, mean_mem, causal, enc_padding, mean_padding)
        return self.doha_dec_norm(hidden)
# ==========================================
# 3. LOAD MODEL
# ==========================================
print("Loading model from HuggingFace...")
MODEL_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="best_model.pt")
TOKENIZER_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="tokenizer.model")

# SentencePiece tokenizer used for the prompt, the meaning and the doha.
sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_PATH)
VOCAB_SIZE = sp.get_piece_size()

# Doha sampling stops at EOS or at the double danda "॥" that closes a doha.
# piece_to_id() falls back to the <unk> id for pieces that are not in the
# vocabulary; using that id as a stop token would end the doha on any sampled
# <unk>, so "॥" is only added when it is a real vocabulary piece.
DANDAA_ID = sp.piece_to_id("॥")
STOP_TOKENS = {cfg.EOS_ID}
if DANDAA_ID != sp.unk_id():
    STOP_TOKENS.add(DANDAA_ID)

model = UnifiedDohaModel(
    VOCAB_SIZE, cfg.D_MODEL, cfg.N_HEADS, cfg.N_ENC_LAYERS,
    cfg.N_MEANING_DEC_LAYERS, cfg.N_DOHA_DEC_LAYERS,
    cfg.D_FF, cfg.DROPOUT, cfg.MAX_SEQ_LEN, cfg.PAD_ID,
).to(DEVICE)

# NOTE(review): torch.load unpickles the downloaded file, so HF_REPO_ID must point
# at a trusted repository. Consider weights_only=True once the checkpoint is
# confirmed to hold only tensors and plain containers.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)

# Keys may carry a leading "module." (presumably saved from a DataParallel-wrapped
# model). Strip only that leading prefix, never the substring elsewhere in a key.
_CKPT_PREFIX = "module."
state_dict = {
    (key[len(_CKPT_PREFIX):] if key.startswith(_CKPT_PREFIX) else key): tensor
    for key, tensor in ckpt["model_state"].items()
}
model.load_state_dict(state_dict)
model.eval()  # inference only: disables dropout
print(f"Model loaded. Vocab: {VOCAB_SIZE}")
# ==========================================
# 4. SAMPLING
# ==========================================
def top_k_top_p_sample(logits, temperature=0.8, top_k=50, top_p=0.92, past_ids=None, rep_penalty=1.3, rep_window=20):
    """Sample one token id from a single row of logits.

    Steps: temperature scaling, repetition penalty, top-k filtering, then
    nucleus (top-p) sampling.

    Args:
        logits: tensor of shape (1, vocab) or (vocab,); it is not modified.
        temperature: softmax temperature. A value <= 0 selects the arg-max
            (greedy) token after the repetition penalty, instead of dividing by zero.
        top_k: keep only the k highest logits (ties included); 0 disables it.
        top_p: keep the smallest prefix of the sorted distribution whose mass
            before each kept token is <= top_p; 1.0 disables it.
        past_ids: previously generated ids; the distinct ids among the last
            *rep_window* of them are penalised.
        rep_penalty: > 1.0 divides positive logits and multiplies negative
            logits of penalised ids (both push them down).
        rep_window: how many trailing ids of *past_ids* are considered.

    Returns:
        A 1-element LongTensor; call ``.item()`` for the int id.
    """
    greedy = temperature <= 0
    logits = logits.squeeze(0).float().clone()  # never mutate the caller's tensor
    if not greedy:
        logits /= temperature

    if past_ids and rep_penalty > 1.0 and rep_window > 0:
        # Vectorised penalty: one gather/scatter instead of a Python loop that
        # reads a tensor element (and syncs with the device) per id.
        recent = torch.tensor(
            sorted(set(past_ids[-rep_window:])), dtype=torch.long, device=logits.device
        )
        vals = logits[recent]
        logits[recent] = torch.where(vals > 0, vals / rep_penalty, vals * rep_penalty)

    if greedy:
        return logits.argmax(dim=-1, keepdim=True)

    top_k = int(top_k)  # UI sliders may deliver floats; torch.topk needs an int
    if top_k > 0:
        kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[-1]
        logits[logits < kth_best] = float("-inf")

    probs = F.softmax(logits, dim=-1)
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Drop a token once the mass *before* it already exceeds top_p, so the
        # most likely token is always kept. multinomial accepts unnormalised weights.
        sorted_probs[(cumulative - sorted_probs) > top_p] = 0.0
        return sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
    return torch.multinomial(probs, 1)
def _prompt_ids(text):
    """Tokenize *text* into a (1, L) LongTensor on DEVICE, capped at cfg.MAX_SEQ_LEN tokens.

    The cap keeps long user contexts (plus the appended meaning in pass 2) within
    the positional-encoding table, which would otherwise fail on over-long input.
    """
    ids = sp.encode(text)[: cfg.MAX_SEQ_LEN]
    return torch.tensor([ids], dtype=torch.long, device=DEVICE)


@torch.no_grad()
def generate(theme, context, temperature=0.8, top_k=50, top_p=0.92):
    """Generate a (meaning, doha) pair for a theme and context with two-pass inference.

    Pass 1 encodes the theme and context, then samples meaning tokens with the
    meaning decoder until EOS or cfg.MAX_MEANING_LEN tokens.
    Pass 2 re-encodes the prompt with the generated meaning appended and
    re-runs the meaning decoder over the sampled meaning ids to build the
    meaning memory. It then samples doha tokens until a STOP_TOKENS id or
    cfg.MAX_DOHA_LEN tokens.

    Returns:
        (meaning_text, doha_text). The doha keeps a closing "॥" if one was
        sampled; a trailing EOS is dropped.
    """
    # ---- Pass 1: meaning ----
    enc_text = f"<theme> {theme} </theme> <context> {context} </context>"
    enc_ids = _prompt_ids(enc_text)
    enc_mask = enc_ids != cfg.PAD_ID
    enc_mem = model.encode(enc_ids, enc_mask)

    mean_ids = [cfg.BOS_ID]
    for _ in range(cfg.MAX_MEANING_LEN):
        m_in = torch.tensor([mean_ids], dtype=torch.long, device=DEVICE)
        m_out = model.decode_meaning(m_in, enc_mem, enc_mask)
        next_id = top_k_top_p_sample(
            model.meaning_proj(m_out[:, -1, :]),
            temperature=temperature, top_k=top_k, top_p=top_p,
            past_ids=mean_ids, rep_penalty=cfg.GEN_REP_PENALTY,
        ).item()
        if next_id == cfg.EOS_ID:
            break
        mean_ids.append(next_id)
    gen_meaning = sp.decode(mean_ids[1:])

    # ---- Pass 2: doha, conditioned on prompt + meaning ----
    enc_text_full = f"{enc_text} <meaning> {gen_meaning} </meaning>"
    enc_ids_full = _prompt_ids(enc_text_full)
    enc_mask_full = enc_ids_full != cfg.PAD_ID
    enc_mem_full = model.encode(enc_ids_full, enc_mask_full)

    mean_tensor = torch.tensor([mean_ids], dtype=torch.long, device=DEVICE)
    m_mem = model.decode_meaning(mean_tensor, enc_mem_full, enc_mask_full)
    m_mask = mean_tensor != cfg.PAD_ID

    doha_ids = [cfg.BOS_ID]
    for _ in range(cfg.MAX_DOHA_LEN):
        d_in = torch.tensor([doha_ids], dtype=torch.long, device=DEVICE)
        d_out = model.decode_doha(d_in, enc_mem_full, m_mem, enc_mask_full, m_mask)
        next_id = top_k_top_p_sample(
            model.doha_proj(d_out[:, -1, :]),
            temperature=temperature, top_k=top_k, top_p=top_p,
            past_ids=doha_ids, rep_penalty=cfg.GEN_DOHA_REP_PEN,
        ).item()
        # The stop token is kept so a closing "॥" stays part of the doha.
        doha_ids.append(next_id)
        if next_id in STOP_TOKENS:
            break

    body = doha_ids[1:]
    if body and body[-1] == cfg.EOS_ID:
        body = body[:-1]
    return gen_meaning, sp.decode(body)
# ==========================================
# 5. MATRA EVALUATION
# ==========================================
# Devanagari marks used by the matra tokenizer.
HALANT = "\u094D"        # virama: a consonant carrying it is a half consonant (weight 0)
ANUSVARA = "\u0902"      # makes the unit it is attached to weigh 2
CHANDRABINDU = "\u0901"  # kept in the unit but does not change its weight
VISARGA = "\u0903"       # makes the unit it is attached to weigh 2
NUKTA = "\u093C"         # dot below a consonant; absorbed into the consonant unit
# Independent vowel letters -> matra weight (1 = laghu/short, 2 = guru/long).
SWAR_WEIGHT = {
    "\u0905": 1, "\u0906": 2, "\u0907": 1, "\u0908": 2, "\u0909": 1, "\u090A": 2,
    "\u090B": 1, "\u090C": 1, "\u090F": 2, "\u0910": 2, "\u0913": 2, "\u0914": 2,
}
# Dependent vowel signs (attached to a consonant) -> matra weight.
# NOTE(review): candra vowels (U+0911 / U+0945 / U+0949) are not listed, so they
# contribute nothing beyond the bare consonant; confirm the intended weights.
MATRA_WEIGHT = {
    "\u093E": 2, "\u093F": 1, "\u0940": 2, "\u0941": 1, "\u0942": 2, "\u0943": 1,
    "\u0947": 2, "\u0948": 2, "\u094B": 2, "\u094C": 2,
}


def is_consonant(ch):
    """Return True if the single character *ch* is a Devanagari consonant.

    Accepts U+0915-U+0939 and the precomposed nukta letters U+0958-U+095F.
    """
    cp = ord(ch)
    return (0x0915 <= cp <= 0x0939) or (0x0958 <= cp <= 0x095F)


def tokenize_dev(word):
    """Split a Devanagari *word* into prosodic units with matra weights.

    The input is NFC-normalised first. Each returned item is
    ``{"unit": str, "weight": int}``:
      * independent vowel: its SWAR_WEIGHT, raised to 2 by anusvara/visarga;
      * consonant (+ nukta) with inherent vowel or vowel sign: 1 or the
        MATRA_WEIGHT, raised to 2 by anusvara/visarga;
      * half consonant (consonant + halant): weight 0, and the preceding
        vowel-bearing unit of the word becomes 2 (a syllable before a
        conjunct is long).
    Any other character (space, punctuation, digit, unknown mark) is skipped.
    """
    chars = list(unicodedata.normalize("NFC", word))
    n = len(chars)
    tokens = []
    last_syllable = None  # index in tokens of the latest vowel-bearing unit
    i = 0
    while i < n:
        ch = chars[i]
        if ch in SWAR_WEIGHT:
            unit, weight = ch, SWAR_WEIGHT[ch]
            i += 1
            while i < n and chars[i] in (ANUSVARA, VISARGA, CHANDRABINDU):
                if chars[i] != CHANDRABINDU:
                    weight = 2
                unit += chars[i]
                i += 1
            tokens.append({"unit": unit, "weight": weight})
            last_syllable = len(tokens) - 1
        elif is_consonant(ch):
            unit = ch
            i += 1
            if i < n and chars[i] == NUKTA:
                unit += chars[i]
                i += 1
            if i < n and chars[i] == HALANT:
                unit += chars[i]
                i += 1
                # Lengthen the last *vowel-bearing* unit, not tokens[-1]: inside a
                # cluster such as "स्त्र" tokens[-1] is itself a half consonant and
                # must stay at weight 0.
                if last_syllable is not None:
                    tokens[last_syllable]["weight"] = 2
                tokens.append({"unit": unit, "weight": 0})
            else:
                weight = 1  # inherent short vowel
                if i < n and chars[i] in MATRA_WEIGHT:
                    weight = MATRA_WEIGHT[chars[i]]
                    unit += chars[i]
                    i += 1
                while i < n and chars[i] in (ANUSVARA, VISARGA, CHANDRABINDU):
                    if chars[i] != CHANDRABINDU:
                        weight = 2
                    unit += chars[i]
                    i += 1
                tokens.append({"unit": unit, "weight": weight})
                last_syllable = len(tokens) - 1
        else:
            i += 1
    return tokens


def count_matra(text):
    """Return the total matra weight of *text* (sum of tokenize_dev weights)."""
    return sum(t["weight"] for t in tokenize_dev(text))
def parse_line(line):
    """Return the matra counts of the charans found in one doha line.

    If the line contains a comma, every non-empty comma-separated segment is
    one charan. Otherwise words are accumulated greedily and a charan is closed
    as soon as its running total reaches 13; any remainder is a final charan.
    Blank input yields an empty list.
    """
    stripped = line.strip()
    if not stripped:
        return []

    if "," in stripped:
        segments = (segment.strip() for segment in stripped.split(","))
        return [count_matra(segment) for segment in segments if segment]

    charans = []
    running = 0
    for word in stripped.split():
        running += count_matra(word)
        if running >= 13:
            charans.append(running)
            running = 0
    if running > 0:
        charans.append(running)
    return charans
def get_charan_matras(doha_text):
    """Return the matra count of every charan in *doha_text*, in order.

    "॥" is treated like "।"; the text is split on danda and each non-blank
    line is handed to parse_line.
    """
    normalized = doha_text.replace("॥", "।").strip()
    matras = []
    for segment in normalized.split("।"):
        segment = segment.strip()
        if segment:
            matras.extend(parse_line(segment))
    return matras
def compute_mas(cm):
    """Return the MAS of a list of charan matra counts (0 = perfect doha).

    MAS is the sum of absolute deviations of the first four counts from the
    ideal 13-11-13-11 pattern. Missing charans count as 0; charans beyond the
    fourth are ignored.
    """
    target = (13, 11, 13, 11)
    observed = (cm + [0] * 4)[:4]
    return sum(abs(got - want) for got, want in zip(observed, target))
def evaluate_doha(doha_text):
    """Score a doha's meter.

    Returns:
        (cm4, mas, total): the first four charan matra counts (zero-padded),
        compute_mas of all charans found, and the sum of cm4.
    """
    charan_matras = get_charan_matras(doha_text)
    first_four = (charan_matras + [0] * 4)[:4]
    return first_four, compute_mas(charan_matras), sum(first_four)
# ==========================================
# 6. BEST-OF-N
# ==========================================
def generate_best_of_n(theme, context, n, temperature, top_k, top_p):
    """Generate *n* candidates and return them sorted by MAS, best (lowest) first.

    Each candidate is a dict with keys "meaning", "doha", "cm4", "mas" and
    "total" (see evaluate_doha). The sort is stable, so ties keep generation order.
    """
    def _attempt():
        meaning, doha = generate(theme, context, temperature, int(top_k), top_p)
        cm4, mas, total = evaluate_doha(doha)
        return {"meaning": meaning, "doha": doha, "cm4": cm4, "mas": mas, "total": total}

    attempts = [_attempt() for _ in range(n)]
    return sorted(attempts, key=lambda cand: cand["mas"])
# ==========================================
# 7. MAIN HANDLER
# ==========================================
def _md_cell(text):
    """Make *text* safe inside one Markdown table cell (escape pipes, drop line breaks)."""
    return text.replace("|", "\\|").replace("\n", " ")


def _format_metrics(best):
    """Render the per-charan matra table and MAS verdict of *best* as Markdown."""
    cm4 = best["cm4"]
    mas = best["mas"]
    quality = "✅ Perfect!" if mas == 0 else ("🟡 Good" if mas <= 4 else "🔴 High deviation")
    rows = "".join(
        f"| चरण {idx + 1} | {ideal} | {got} | {abs(got - ideal)} |\n"
        for idx, (got, ideal) in enumerate(zip(cm4, (13, 11, 13, 11)))
    )
    return (
        "**📊 मात्रा विश्लेषण (Matra Analysis — Best Doha)**\n\n"
        "| चरण | आदर्श | प्राप्त | अंतर |\n"
        "|:---:|:-----:|:------:|:----:|\n"
        + rows
        + "\n"
        # A blank line (not a single trailing space) so Markdown renders the
        # total and the MAS verdict as separate lines.
        + f"**कुल मात्राएँ:** {best['total']} / 48\n\n"
        + f"**MAS:** {mas} {quality}"
    )


def _format_attempts(candidates):
    """Render all candidates (already sorted, best first) as a Markdown table."""
    rows = []
    for rank, cand in enumerate(candidates):
        marker = "⭐" if rank == 0 else f"{rank + 1}."
        doha = cand["doha"]
        snippet = doha[:55] + "..." if len(doha) > 55 else doha
        rows.append(f"| {marker} | {_md_cell(snippet)} | {cand['total']}/48 | {cand['mas']} |")
    return (
        "**📋 सभी प्रयास — sorted by MAS (low = better)**\n\n"
        "| # | दोहा | मात्राएँ | MAS |\n"
        "|---|------|---------|-----|\n"
        + "\n".join(rows)
    )


def run_generation(theme, context, n_attempts, temperature, top_k, top_p):
    """Gradio handler: generate best-of-N dohas and format them for the UI.

    Returns:
        (doha, meaning, metrics_markdown, attempts_markdown). For missing input,
        the first element is a warning message and the other three are "".
    """
    theme = (theme or "").strip()
    context = (context or "").strip()
    if not theme:
        return "⚠️ कृपया थीम दर्ज करें।", "", "", ""
    if not context:
        return "⚠️ कृपया संदर्भ दर्ज करें।", "", "", ""

    # At least one attempt, so candidates[0] always exists.
    n = max(1, int(n_attempts))
    candidates = generate_best_of_n(theme, context, n, temperature, int(top_k), top_p)
    best = candidates[0]
    return best["doha"], best["meaning"], _format_metrics(best), _format_attempts(candidates)
# ==========================================
# 8. GRADIO 6 UI
# ==========================================
# (theme, context) pairs offered as one-click examples in the UI.
_EXAMPLE_PAIRS = (
    ("साहस", "अंधेरी सुरंग के अंत में प्रकाश की उम्मीद"),
    ("प्रेम", "सच्चा प्रेम"),
    ("विरह", "जुदाई का दुख"),
    ("गुरु", "गुरु से ज्ञान"),
    ("प्रकृति", "वसंत की सुंदरता"),
    ("कर्म", "कर्म का फल"),
    ("धैर्य", "नहीं रुकना"),
    ("आशा", "उम्मीद की किरण"),
    ("ज्ञान", "आत्मज्ञान जरूरी"),
    ("भक्ति", "दिल में भगवान"),
)
# One two-element list per row, matching the two input components given to gr.Examples.
EXAMPLES = [list(pair) for pair in _EXAMPLE_PAIRS]
# UI layout: inputs and sampling controls on the left, generated doha / meaning /
# matra metrics on the right, and an "About" table underneath.
# Markdown string contents stay at column 0: leading indentation would turn them
# into Markdown code blocks.
with gr.Blocks(title="हिंदी दोहा जनरेटर | Hindi Doha Generator") as demo:
    # The blank line before "---" matters: directly under a text line, "---"
    # would turn that line into a setext heading instead of drawing a rule.
    gr.Markdown("""
# 🪔 हिंदी दोहा जनरेटर
### Constrained Generation of Theme-Based Hindi Dohas
*Custom 14.7M parameter Transformer · Dual-Decoder Architecture · Two-Pass Inference*

---
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input")
            theme_input = gr.Textbox(
                label="थीम (Theme)",
                placeholder="e.g. साहस, प्रेम, विरह, गुरु...",
                value="साहस",
                lines=1,
            )
            context_input = gr.Textbox(
                label="संदर्भ (Context)",
                placeholder="e.g. अंधेरी सुरंग के अंत में प्रकाश की उम्मीद",
                value="अंधेरी सुरंग के अंत में प्रकाश की उम्मीद",
                lines=2,
            )
            gr.Markdown("### ⚙️ Generation Settings")
            n_attempts = gr.Slider(1, 10, value=3, step=1, label="Best-of-N Attempts")
            temperature = gr.Slider(0.5, 1.2, value=0.8, step=0.05, label="Temperature (creativity)")
            top_k = gr.Slider(10, 100, value=50, step=5, label="Top-K")
            top_p = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-P")
            generate_btn = gr.Button("🎵 दोहा जनरेट करें", variant="primary", size="lg")
            gr.Markdown("### 💡 Examples — click to load")
            gr.Examples(examples=EXAMPLES, inputs=[theme_input, context_input])
        with gr.Column(scale=1):
            gr.Markdown("### 📜 Output")
            doha_output = gr.Textbox(
                label="🎵 Generated Doha (दोहा)",
                lines=3,
                interactive=False,
            )
            meaning_output = gr.Textbox(
                label="💭 Generated Meaning (अर्थ)",
                lines=5,
                interactive=False,
            )
            metrics_output = gr.Markdown()
            with gr.Accordion("📋 All Attempts Comparison", open=False):
                all_attempts_output = gr.Markdown()
    gr.Markdown("""
---
### 📖 About
| Component | Detail |
|-----------|--------|
| Architecture | Custom Encoder + Dual Decoder Transformer |
| Parameters | 14.7M |
| Stage 1 | Span-corruption pretraining on 58,500 Hindi kavitas |
| Stage 2 | Dual-decoder fine-tuning with learnable gated cross-attention |
| Inference | Two-pass: meaning first, then doha |
| Doha structure | 13+11 \\| 13+11 matras (48 total ideal) |
| MAS formula | \\|m1−13\\| + \\|m2−11\\| + \\|m3−13\\| + \\|m4−11\\| (0 = perfect) |
| Checkpoint | [nikpatidar333/doha-generation-model_v2](https://huggingface.co/nikpatidar333/doha-generation-model_v2) |
""")
    # Output order must match run_generation's return tuple.
    generate_btn.click(
        fn=run_generation,
        inputs=[theme_input, context_input, n_attempts, temperature, top_k, top_p],
        outputs=[doha_output, meaning_output, metrics_output, all_attempts_output],
    )

# Launch only when run as a script, so importing this module (e.g. for tests or
# Gradio reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()