"""Gradio demo: theme-conditioned Hindi doha generation with matra scoring.

The script downloads a checkpoint and a SentencePiece tokenizer from the
Hugging Face Hub, rebuilds the encoder + dual-decoder Transformer, and serves
a UI that samples several candidate dohas and ranks them by how far their
matra counts deviate from the 13/11/13/11 pattern.
"""
import math
import unicodedata

import gradio as gr
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# ==========================================
# 1. CONFIG
# ==========================================


class Config:
    """Model hyper-parameters and generation defaults.

    The architecture values are passed to ``UnifiedDohaModel`` below and must
    therefore agree with the downloaded checkpoint.
    """

    D_MODEL = 256
    N_HEADS = 8
    N_ENC_LAYERS = 4
    N_MEANING_DEC_LAYERS = 4
    N_DOHA_DEC_LAYERS = 4
    D_FF = 1024
    DROPOUT = 0.15
    MAX_SEQ_LEN = 256       # size of the positional-encoding table
    MAX_MEANING_LEN = 60    # max sampled meaning tokens
    MAX_DOHA_LEN = 48       # max sampled doha tokens
    PAD_ID = 0
    BOS_ID = 2
    EOS_ID = 3
    GEN_TEMPERATURE = 0.8
    GEN_TOP_K = 50
    GEN_TOP_P = 0.92
    GEN_REP_PENALTY = 1.3   # repetition penalty for the meaning decoder
    GEN_DOHA_REP_PEN = 1.5  # repetition penalty for the doha decoder


cfg = Config()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_REPO_ID = "nikpatidar333/doha-generation-model_v2"

# ==========================================
# 2. MODEL ARCHITECTURE
# ==========================================


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout."""

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        # Buffer (not a parameter): moves with .to(device), is not trained.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` position vectors to ``x``.

        The table only has ``max_len`` rows, so callers must keep the
        sequence dimension of ``x`` at or below ``max_len``.
        """
        return self.dropout(x + self.pe[:, :x.size(1)])


class DohaDecoderLayer(nn.Module):
    """Decoder layer with two cross-attentions mixed by a learnable gate.

    One cross-attention reads the encoder memory, the other reads the meaning
    decoder's hidden states; ``sigmoid(gate)`` weights the former and
    ``1 - sigmoid(gate)`` the latter.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn_enc = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn_meaning = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.gate = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, encoder_memory, meaning_memory, tgt_mask, enc_kpm, mean_kpm):
        """Run self-attention, gated dual cross-attention, then the FFN.

        ``enc_kpm`` / ``mean_kpm`` are forwarded as ``key_padding_mask`` for
        the encoder and meaning memories respectively.
        """
        # Pre-norm self-attention with residual.
        residual = x
        normed = self.norm1(x)
        attn_out, _ = self.self_attn(normed, normed, normed, attn_mask=tgt_mask)
        x = residual + attn_out

        # Both cross-attentions share the same normalised query.
        residual = x
        query = self.norm2(x)
        enc_out, _ = self.cross_attn_enc(
            query, encoder_memory, encoder_memory, key_padding_mask=enc_kpm
        )
        mean_out, _ = self.cross_attn_meaning(
            query, meaning_memory, meaning_memory, key_padding_mask=mean_kpm
        )
        g = torch.sigmoid(self.gate)
        x = residual + g * enc_out + (1.0 - g) * mean_out

        # Pre-norm feed-forward with residual, then a final output norm.
        residual = x
        x = residual + self.ffn(self.norm3(x))
        return self.norm4(x)


class UnifiedDohaModel(nn.Module):
    """Shared encoder with a meaning decoder and a doha decoder.

    Both output projections are tied to the input embedding matrix.
    Attribute names are part of the checkpoint's state-dict keys and must not
    be renamed.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_enc, n_mdec, n_ddec,
                 d_ff, dropout, max_len, pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)

        enc_l = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_l, num_layers=n_enc)

        m_dec_l = nn.TransformerDecoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        self.meaning_decoder = nn.TransformerDecoder(m_dec_l, num_layers=n_mdec)

        self.doha_dec_layers = nn.ModuleList(
            [DohaDecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_ddec)]
        )
        self.doha_dec_norm = nn.LayerNorm(d_model)

        self.meaning_proj = nn.Linear(d_model, vocab_size, bias=False)
        self.doha_proj = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: both heads reuse the embedding matrix.
        self.meaning_proj.weight = self.embedding.weight
        self.doha_proj.weight = self.embedding.weight

    def _embed(self, ids):
        """Scaled token embedding plus positional encoding."""
        return self.pos_enc(self.embedding(ids) * math.sqrt(self.d_model))

    def encode(self, src, src_mask):
        """Encode ``src``; ``src_mask`` is True where a token is NOT padding."""
        return self.encoder(self._embed(src), src_key_padding_mask=~src_mask)

    def decode_meaning(self, tgt, memory, src_mask):
        """Return meaning-decoder hidden states for ``tgt`` under a causal mask."""
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt.size(1)
        ).to(tgt.device)
        return self.meaning_decoder(
            self._embed(tgt), memory,
            tgt_mask=tgt_mask, memory_key_padding_mask=~src_mask,
        )

    def decode_doha(self, tgt, enc_mem, mean_mem, enc_mask, mean_mask):
        """Return doha-decoder hidden states attending to both memories.

        ``enc_mask`` / ``mean_mask`` are True on real tokens; they are
        inverted here because ``key_padding_mask`` marks positions to ignore.
        """
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt.size(1)
        ).to(tgt.device)
        x = self._embed(tgt)
        for layer in self.doha_dec_layers:
            x = layer(x, enc_mem, mean_mem, tgt_mask, ~enc_mask, ~mean_mask)
        return self.doha_dec_norm(x)


# ==========================================
# 3. LOAD MODEL
# ==========================================
print("Loading model from HuggingFace...")
MODEL_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="best_model.pt")
TOKENIZER_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="tokenizer.model")

sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_PATH)
VOCAB_SIZE = sp.get_piece_size()

DANDAA_ID = sp.piece_to_id("॥")
STOP_TOKENS = {cfg.EOS_ID}
# piece_to_id returns the UNK id for an unknown piece; do not let UNK end
# generation if "॥" is not a standalone vocabulary entry.
if DANDAA_ID != sp.unk_id():
    STOP_TOKENS.add(DANDAA_ID)

model = UnifiedDohaModel(
    VOCAB_SIZE, cfg.D_MODEL, cfg.N_HEADS,
    cfg.N_ENC_LAYERS, cfg.N_MEANING_DEC_LAYERS, cfg.N_DOHA_DEC_LAYERS,
    cfg.D_FF, cfg.DROPOUT, cfg.MAX_SEQ_LEN, cfg.PAD_ID,
).to(DEVICE)

# SECURITY NOTE: torch.load unpickles the file. The checkpoint comes from a
# remote repository, so only point HF_REPO_ID at a source you trust.
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
# Strip only a leading "module." prefix (as left by a DataParallel-style
# wrapper), never an occurrence in the middle of a key.
state_dict = {k.removeprefix("module."): v for k, v in ckpt["model_state"].items()}
model.load_state_dict(state_dict)
model.eval()
print(f"Model loaded. Vocab: {VOCAB_SIZE}")

# ==========================================
# 4. SAMPLING
# ==========================================


def top_k_top_p_sample(logits, temperature=0.8, top_k=50, top_p=0.92,
                       past_ids=None, rep_penalty=1.3):
    """Sample one token id from ``logits`` with top-k / nucleus filtering.

    Args:
        logits: Tensor whose leading dimension of size 1 is squeezed away,
            leaving one score per vocabulary entry.
        temperature: Softmax temperature; clamped to a small positive value
            so a zero from the caller cannot cause a division by zero.
        top_k: Keep only the ``top_k`` highest logits (0 disables).
        top_p: Nucleus threshold (>= 1.0 disables).
        past_ids: Previously generated ids; the last 20 are penalised.
        rep_penalty: Penalty factor; applied only when > 1.0.

    Returns:
        A 1-element long tensor holding the sampled token id.
    """
    # Division allocates a new tensor, so the in-place edits below never
    # touch the caller's logits.
    logits = logits.squeeze(0).float() / max(float(temperature), 1e-5)

    if past_ids and rep_penalty > 1.0:
        recent = torch.tensor(
            sorted(set(past_ids[-20:])), dtype=torch.long, device=logits.device
        )
        vals = logits[recent]
        # Positive scores shrink, negative scores grow more negative.
        logits[recent] = torch.where(vals > 0, vals / rep_penalty, vals * rep_penalty)

    top_k = int(top_k)
    if top_k > 0:
        kth_value = torch.topk(logits, min(top_k, logits.size(-1))).values[-1]
        logits[logits < kth_value] = float("-inf")

    if top_p < 1.0:
        p_sort, p_idx = torch.sort(F.softmax(logits, dim=-1), descending=True)
        cumsum = torch.cumsum(p_sort, dim=-1)
        # Drop tokens whose preceding cumulative mass already exceeds top_p;
        # the most likely token always survives (its preceding mass is 0).
        p_sort[(cumsum - p_sort) > top_p] = 0.0
        next_token = torch.multinomial(p_sort, 1)
        return p_idx.gather(-1, next_token)

    return torch.multinomial(F.softmax(logits, dim=-1), 1)


def _encode_source(text):
    """Tokenize ``text`` into a (1, T) id tensor and its non-pad mask.

    Ids are truncated to ``cfg.MAX_SEQ_LEN`` because the positional-encoding
    table has exactly that many rows; longer inputs would fail with a shape
    mismatch inside the encoder.
    """
    ids = sp.encode(text)[: cfg.MAX_SEQ_LEN]
    src = torch.tensor([ids], dtype=torch.long).to(DEVICE)
    return src, src != cfg.PAD_ID


@torch.no_grad()
def generate(theme, context, temperature=0.8, top_k=50, top_p=0.92):
    """Two-pass generation: sample a meaning, then a doha conditioned on it.

    Pass 1 encodes the prompt and samples the meaning token by token.
    Pass 2 re-encodes prompt + decoded meaning, recomputes the meaning
    decoder's hidden states against that memory, and samples the doha with
    attention over both.

    Returns:
        ``(meaning_text, doha_text)``.
    """
    # NOTE(review): the prompt layout is kept exactly as found; confirm it
    # matches the format used at training time.
    enc_text = f" {theme} {context} "
    enc_ids, enc_mask = _encode_source(enc_text)
    enc_mem = model.encode(enc_ids, enc_mask)

    mean_ids = [cfg.BOS_ID]
    for _ in range(cfg.MAX_MEANING_LEN):
        m_in = torch.tensor([mean_ids], dtype=torch.long).to(DEVICE)
        m_out = model.decode_meaning(m_in, enc_mem, enc_mask)
        next_id = top_k_top_p_sample(
            model.meaning_proj(m_out[:, -1, :]),
            temperature=temperature, top_k=top_k, top_p=top_p,
            past_ids=mean_ids, rep_penalty=cfg.GEN_REP_PENALTY,
        ).item()
        if next_id == cfg.EOS_ID:
            break
        mean_ids.append(next_id)
    gen_meaning = sp.decode(mean_ids[1:])

    enc_text_full = f"{enc_text} {gen_meaning} "
    enc_ids_full, enc_mask_full = _encode_source(enc_text_full)
    enc_mem_full = model.encode(enc_ids_full, enc_mask_full)

    mean_tensor = torch.tensor([mean_ids], dtype=torch.long).to(DEVICE)
    m_mem = model.decode_meaning(mean_tensor, enc_mem_full, enc_mask_full)
    m_mask = mean_tensor != cfg.PAD_ID

    doha_ids = [cfg.BOS_ID]
    for _ in range(cfg.MAX_DOHA_LEN):
        d_in = torch.tensor([doha_ids], dtype=torch.long).to(DEVICE)
        d_out = model.decode_doha(d_in, enc_mem_full, m_mem, enc_mask_full, m_mask)
        next_id = top_k_top_p_sample(
            model.doha_proj(d_out[:, -1, :]),
            temperature=temperature, top_k=top_k, top_p=top_p,
            past_ids=doha_ids, rep_penalty=cfg.GEN_DOHA_REP_PEN,
        ).item()
        # The stop token is kept in the output before stopping.
        doha_ids.append(next_id)
        if next_id in STOP_TOKENS:
            break

    return gen_meaning, sp.decode(doha_ids[1:])


# ==========================================
# 5. MATRA EVALUATION
# ==========================================
HALANT = "\u094D"
ANUSVARA = "\u0902"
CHANDRABINDU = "\u0901"
VISARGA = "\u0903"
NUKTA = "\u093C"

# Independent vowels -> matra weight (1 = laghu, 2 = guru).
SWAR_WEIGHT = {
    "\u0905": 1, "\u0906": 2, "\u0907": 1, "\u0908": 2, "\u0909": 1, "\u090A": 2,
    "\u090B": 1, "\u090C": 1, "\u090F": 2, "\u0910": 2, "\u0913": 2, "\u0914": 2,
}
# Dependent vowel signs -> matra weight.
MATRA_WEIGHT = {
    "\u093E": 2, "\u093F": 1, "\u0940": 2, "\u0941": 1, "\u0942": 2, "\u0943": 1,
    "\u0947": 2, "\u0948": 2, "\u094B": 2, "\u094C": 2,
}


def is_consonant(ch):
    """Return True if ``ch`` lies in U+0915-U+0939 or U+0958-U+095F."""
    cp = ord(ch)
    return (0x0915 <= cp <= 0x0939) or (0x0958 <= cp <= 0x095F)


def _promote_previous_syllable(tokens):
    """Make the nearest preceding vowel-bearing token heavy (weight 2).

    Zero-weight tokens are halant (half) consonants; they carry no matra of
    their own and are skipped so that clusters such as "स्त्र" never assign
    weight to a half consonant.
    """
    for tok in reversed(tokens):
        if tok["weight"] > 0:
            tok["weight"] = 2
            return


def tokenize_dev(word):
    """Split Devanagari text into units annotated with matra weights.

    Returns:
        A list of ``{"unit": str, "weight": int}`` dicts. Half consonants get
        weight 0 and make the preceding syllable heavy; characters that are
        neither vowels nor consonants are skipped.
    """
    word = unicodedata.normalize("NFC", word)
    tokens = []
    chars = list(word)
    i, n = 0, len(chars)
    while i < n:
        ch = chars[i]
        if ch in SWAR_WEIGHT:
            weight = SWAR_WEIGHT[ch]
            unit = ch
            i += 1
            while i < n and chars[i] in (ANUSVARA, VISARGA):
                weight = 2
                unit += chars[i]
                i += 1
            tokens.append({"unit": unit, "weight": weight})
        elif is_consonant(ch):
            unit = ch
            i += 1
            if i < n and chars[i] == NUKTA:
                unit += chars[i]
                i += 1
            if i < n and chars[i] == HALANT:
                unit += chars[i]
                i += 1
                # A conjunct lengthens the syllable before it, not the half
                # consonant that happens to be the previous token.
                _promote_previous_syllable(tokens)
                tokens.append({"unit": unit, "weight": 0})
            else:
                mw = 1  # inherent "a" vowel
                if i < n and chars[i] in MATRA_WEIGHT:
                    mw = MATRA_WEIGHT[chars[i]]
                    unit += chars[i]
                    i += 1
                while i < n and chars[i] in (ANUSVARA, VISARGA, CHANDRABINDU):
                    # Chandrabindu is absorbed without changing the weight.
                    if chars[i] in (ANUSVARA, VISARGA):
                        mw = 2
                    unit += chars[i]
                    i += 1
                tokens.append({"unit": unit, "weight": mw})
        else:
            i += 1
    return tokens


def count_matra(text):
    """Return the total matra weight of ``text``."""
    return sum(t["weight"] for t in tokenize_dev(text))


def parse_line(line):
    """Return the matra count of each charan found in one doha line.

    If the line contains commas, each comma-separated part is one charan.
    Otherwise words are accumulated greedily and a charan is closed as soon
    as its running total reaches 13; any remainder forms the last charan.
    """
    line = line.strip()
    if not line:
        return []
    if "," in line:
        return [count_matra(p.strip()) for p in line.split(",") if p.strip()]

    charans = []
    curr = 0
    for w in line.split():
        curr += count_matra(w)
        if curr >= 13:
            charans.append(curr)
            curr = 0
    if curr > 0:
        charans.append(curr)
    return charans


def get_charan_matras(doha_text):
    """Split a doha on danda marks and return all charan matra counts."""
    text = doha_text.replace("॥", "।").strip()
    out = []
    for line in text.split("।"):
        line = line.strip()
        if line:
            out.extend(parse_line(line))
    return out


def compute_mas(cm):
    """Sum of absolute deviations of the first four charans from 13/11/13/11.

    Missing charans count as 0 matras; charans beyond the fourth are ignored.
    """
    ideal = [13, 11, 13, 11]
    cm4 = (cm + [0] * 4)[:4]
    return sum(abs(got - want) for got, want in zip(cm4, ideal))


def evaluate_doha(doha_text):
    """Return ``(first_four_charan_counts, mas, total_of_those_four)``."""
    cm = get_charan_matras(doha_text)
    cm4 = (cm + [0] * 4)[:4]
    return cm4, compute_mas(cm), sum(cm4)


# ==========================================
# 6. BEST-OF-N
# ==========================================


def generate_best_of_n(theme, context, n, temperature, top_k, top_p):
    """Generate ``n`` candidates (at least one) sorted by ascending MAS."""
    n = max(1, int(n))  # guarantees callers can index candidates[0]
    candidates = []
    for _ in range(n):
        meaning, doha = generate(theme, context, temperature, int(top_k), top_p)
        cm4, mas, total = evaluate_doha(doha)
        candidates.append(
            {"meaning": meaning, "doha": doha, "cm4": cm4, "mas": mas, "total": total}
        )
    candidates.sort(key=lambda c: c["mas"])
    return candidates


# ==========================================
# 7. MAIN HANDLER
# ==========================================


def _table_cell(text, limit=55):
    """Flatten, truncate and escape ``text`` for use in a Markdown table cell."""
    flat = " ".join(text.split())  # newlines would terminate the table row
    if len(flat) > limit:
        flat = flat[:limit] + "..."
    return flat.replace("|", "\\|")


def run_generation(theme, context, n_attempts, temperature, top_k, top_p):
    """Gradio callback: generate, pick the best doha, and build the reports.

    Returns:
        ``(doha, meaning, metrics_markdown, all_attempts_markdown)``; when the
        theme or context is blank, a warning in the first slot and empty
        strings elsewhere.
    """
    if not theme.strip():
        return "⚠️ कृपया थीम दर्ज करें।", "", "", ""
    if not context.strip():
        return "⚠️ कृपया संदर्भ दर्ज करें।", "", "", ""

    candidates = generate_best_of_n(
        theme, context, int(n_attempts), temperature, int(top_k), top_p
    )
    best = candidates[0]
    doha_out = best["doha"]
    meaning_out = best["meaning"]
    cm4 = best["cm4"]
    mas = best["mas"]
    total = best["total"]

    if mas == 0:
        quality = "✅ Perfect!"
    elif mas <= 4:
        quality = "🟡 Good"
    else:
        quality = "🔴 High deviation"

    metrics_out = (
        "**📊 मात्रा विश्लेषण (Matra Analysis — Best Doha)**\n\n"
        "| चरण | आदर्श | प्राप्त | अंतर |\n"
        "|:---:|:-----:|:------:|:----:|\n"
        f"| चरण 1 | 13 | {cm4[0]} | {abs(cm4[0]-13)} |\n"
        f"| चरण 2 | 11 | {cm4[1]} | {abs(cm4[1]-11)} |\n"
        f"| चरण 3 | 13 | {cm4[2]} | {abs(cm4[2]-13)} |\n"
        f"| चरण 4 | 11 | {cm4[3]} | {abs(cm4[3]-11)} |\n\n"
        # Two trailing spaces = Markdown hard line break before the MAS line.
        f"**कुल मात्राएँ:** {total} / 48  \n"
        f"**MAS:** {mas} {quality}"
    )

    rows = []
    for i, c in enumerate(candidates):
        marker = "⭐" if i == 0 else f"{i+1}."
        rows.append(
            f"| {marker} | {_table_cell(c['doha'])} | {c['total']}/48 | {c['mas']} |"
        )

    all_table = (
        "**📋 सभी प्रयास — sorted by MAS (low = better)**\n\n"
        "| # | दोहा | मात्राएँ | MAS |\n"
        "|---|------|---------|-----|\n"
        + "\n".join(rows)
    )

    return doha_out, meaning_out, metrics_out, all_table


# ==========================================
# 8. GRADIO 6 UI
# ==========================================
EXAMPLES = [
    ["साहस", "अंधेरी सुरंग के अंत में प्रकाश की उम्मीद"],
    ["प्रेम", "सच्चा प्रेम"],
    ["विरह", "जुदाई का दुख"],
    ["गुरु", "गुरु से ज्ञान"],
    ["प्रकृति", "वसंत की सुंदरता"],
    ["कर्म", "कर्म का फल"],
    ["धैर्य", "नहीं रुकना"],
    ["आशा", "उम्मीद की किरण"],
    ["ज्ञान", "आत्मज्ञान जरूरी"],
    ["भक्ति", "दिल में भगवान"],
]

# Markdown kept at column 0 so rendering never depends on code indentation.
HEADER_MD = """
# 🪔 हिंदी दोहा जनरेटर
### Constrained Generation of Theme-Based Hindi Dohas
*Custom 14.7M parameter Transformer · Dual-Decoder Architecture · Two-Pass Inference*

---
"""

ABOUT_MD = """
---
### 📖 About
| Component | Detail |
|-----------|--------|
| Architecture | Custom Encoder + Dual Decoder Transformer |
| Parameters | 14.7M |
| Stage 1 | Span-corruption pretraining on 58,500 Hindi kavitas |
| Stage 2 | Dual-decoder fine-tuning with learnable gated cross-attention |
| Inference | Two-pass: meaning first, then doha |
| Doha structure | 13+11 \\| 13+11 matras (48 total ideal) |
| MAS formula | \\|m1−13\\| + \\|m2−11\\| + \\|m3−13\\| + \\|m4−11\\| (0 = perfect) |
| Checkpoint | [nikpatidar333/doha-generation-model_v2](https://huggingface.co/nikpatidar333/doha-generation-model_v2) |
"""

with gr.Blocks(title="हिंदी दोहा जनरेटर | Hindi Doha Generator") as demo:
    gr.Markdown(HEADER_MD)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input")
            theme_input = gr.Textbox(
                label="थीम (Theme)",
                placeholder="e.g. साहस, प्रेम, विरह, गुरु...",
                value="साहस",
                lines=1,
            )
            context_input = gr.Textbox(
                label="संदर्भ (Context)",
                placeholder="e.g. अंधेरी सुरंग के अंत में प्रकाश की उम्मीद",
                value="अंधेरी सुरंग के अंत में प्रकाश की उम्मीद",
                lines=2,
            )
            gr.Markdown("### ⚙️ Generation Settings")
            n_attempts = gr.Slider(1, 10, value=3, step=1, label="Best-of-N Attempts")
            temperature = gr.Slider(
                0.5, 1.2, value=0.8, step=0.05, label="Temperature (creativity)"
            )
            top_k = gr.Slider(10, 100, value=50, step=5, label="Top-K")
            top_p = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-P")
            generate_btn = gr.Button("🎵 दोहा जनरेट करें", variant="primary", size="lg")
            gr.Markdown("### 💡 Examples — click to load")
            gr.Examples(examples=EXAMPLES, inputs=[theme_input, context_input])

        with gr.Column(scale=1):
            gr.Markdown("### 📜 Output")
            doha_output = gr.Textbox(
                label="🎵 Generated Doha (दोहा)", lines=3, interactive=False
            )
            meaning_output = gr.Textbox(
                label="💭 Generated Meaning (अर्थ)", lines=5, interactive=False
            )
            metrics_output = gr.Markdown()
            with gr.Accordion("📋 All Attempts Comparison", open=False):
                all_attempts_output = gr.Markdown()

    gr.Markdown(ABOUT_MD)

    generate_btn.click(
        fn=run_generation,
        inputs=[theme_input, context_input, n_attempts, temperature, top_k, top_p],
        outputs=[doha_output, meaning_output, metrics_output, all_attempts_output],
    )

# Launch only when run as a script so importing this module (tests, reload
# tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()