Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import sentencepiece as spm | |
| import math | |
| import unicodedata | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| # ========================================== | |
| # 1. CONFIG | |
| # ========================================== | |
class Config:
    """Hyper-parameters used to build the model and to decode from it.

    The architecture sizes are passed straight into the model constructor
    before a strict ``load_state_dict`` of the downloaded checkpoint, so they
    must agree with the shapes stored in that checkpoint.
    """

    # --- architecture ---
    D_MODEL = 256
    N_HEADS = 8
    N_ENC_LAYERS = 4
    N_MEANING_DEC_LAYERS = 4
    N_DOHA_DEC_LAYERS = 4
    D_FF = 1024
    DROPOUT = 0.15
    MAX_SEQ_LEN = 256  # also the size of the positional-encoding table

    # --- decoding length limits (in tokens) ---
    MAX_MEANING_LEN = 60
    MAX_DOHA_LEN = 48

    # --- special token ids ---
    PAD_ID = 0
    BOS_ID = 2
    EOS_ID = 3

    # --- sampling settings ---
    GEN_TEMPERATURE = 0.8
    GEN_TOP_K = 50
    GEN_TOP_P = 0.92
    GEN_REP_PENALTY = 1.3   # repetition penalty for the meaning decoder
    GEN_DOHA_REP_PEN = 1.5  # repetition penalty for the doha decoder


cfg = Config()
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
HF_REPO_ID = "nikpatidar333/doha-generation-model_v2"
| # ========================================== | |
| # 2. MODEL ARCHITECTURE | |
| # ========================================== | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Args:
        d_model: embedding width. Sines fill the even columns and cosines the
            odd ones, which only lines up when ``d_model`` is even.
        max_len: number of precomputed positions; a sequence longer than this
            cannot be encoded (the slice in ``forward`` would be too short).
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq  # (max_len, d_model // 2)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # A buffer (not a parameter): moves with .to(device) and is saved in state_dict.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        # Positions are taken along dim 1, so x is presumably (batch, seq, d_model).
        seq_len = x.size(1)
        return self.dropout(x + self.pe[:, :seq_len])
class DohaDecoderLayer(nn.Module):
    """Decoder layer attending to its own prefix plus two external memories.

    Self-attention and the two cross-attentions (over encoder memory and over
    meaning-decoder states) are pre-normalised residual blocks. The two
    cross-attention outputs are mixed by a single learnable scalar gate
    passed through a sigmoid. A final LayerNorm is applied to the layer output.

    Attribute names are state_dict keys and must match the checkpoint that the
    enclosing model is loaded from; creation order is also kept unchanged so
    seeded initialisation is reproducible.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.norm1, self.norm2, self.norm3, self.norm4 = (nn.LayerNorm(d_model) for _ in range(4))
        mha_options = dict(dropout=dropout, batch_first=True)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, **mha_options)
        self.cross_attn_enc = nn.MultiheadAttention(d_model, n_heads, **mha_options)
        self.cross_attn_meaning = nn.MultiheadAttention(d_model, n_heads, **mha_options)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        # Raw gate logit; sigmoid(0.5) ~= 0.62 initial weight on the encoder branch.
        self.gate = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, encoder_memory, meaning_memory, tgt_mask, enc_kpm, mean_kpm):
        """Apply the layer.

        Args:
            x: decoder states; attention is batch_first, so presumably
                (batch, tgt_len, d_model).
            encoder_memory: keys/values for the encoder cross-attention.
            meaning_memory: keys/values for the meaning cross-attention.
            tgt_mask: ``attn_mask`` for self-attention (e.g. a causal mask).
            enc_kpm, mean_kpm: key-padding masks for the two memories, in
                ``nn.MultiheadAttention`` convention (True = ignore that key).
        """
        normed = self.norm1(x)
        self_ctx, _ = self.self_attn(normed, normed, normed, attn_mask=tgt_mask)
        hidden = x + self_ctx

        normed = self.norm2(hidden)
        enc_ctx, _ = self.cross_attn_enc(
            normed, encoder_memory, encoder_memory, key_padding_mask=enc_kpm
        )
        mean_ctx, _ = self.cross_attn_meaning(
            normed, meaning_memory, meaning_memory, key_padding_mask=mean_kpm
        )
        mix = self.gate.sigmoid()
        hidden = hidden + mix * enc_ctx + (1.0 - mix) * mean_ctx

        hidden = hidden + self.ffn(self.norm3(hidden))
        return self.norm4(hidden)
class UnifiedDohaModel(nn.Module):
    """Shared-embedding encoder with two decoders (meaning, then doha).

    * ``encode`` runs a pre-norm ``nn.TransformerEncoder`` over the source.
    * ``decode_meaning`` runs a pre-norm ``nn.TransformerDecoder`` against the
      encoder memory under a causal mask.
    * ``decode_doha`` runs a stack of ``DohaDecoderLayer`` blocks that attend to
      both the encoder memory and the meaning-decoder states, then a LayerNorm.

    Both vocabulary projections share their weight with the input embedding.
    Mask arguments are inverted with ``~`` before being used as key-padding
    masks, so callers presumably pass True for real (non-pad) positions.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_enc, n_mdec, n_ddec, d_ff, dropout, max_len, pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)

        layer_options = dict(dropout=dropout, batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, n_heads, d_ff, **layer_options),
            num_layers=n_enc,
        )
        self.meaning_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, n_heads, d_ff, **layer_options),
            num_layers=n_mdec,
        )
        self.doha_dec_layers = nn.ModuleList(
            DohaDecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_ddec)
        )
        self.doha_dec_norm = nn.LayerNorm(d_model)

        self.meaning_proj = nn.Linear(d_model, vocab_size, bias=False)
        self.doha_proj = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: both output heads reuse the embedding matrix.
        self.meaning_proj.weight = self.embedding.weight
        self.doha_proj.weight = self.embedding.weight

    def _embed(self, ids):
        # Token embeddings scaled by sqrt(d_model), plus positions (+ dropout).
        return self.pos_enc(self.embedding(ids) * math.sqrt(self.d_model))

    @staticmethod
    def _causal_mask(ids):
        # Square "no peeking ahead" mask sized to the sequence dimension of ids.
        return nn.Transformer.generate_square_subsequent_mask(ids.size(1)).to(ids.device)

    def encode(self, src, src_mask):
        return self.encoder(self._embed(src), src_key_padding_mask=~src_mask)

    def decode_meaning(self, tgt, memory, src_mask):
        causal = self._causal_mask(tgt)
        return self.meaning_decoder(
            self._embed(tgt), memory, tgt_mask=causal, memory_key_padding_mask=~src_mask
        )

    def decode_doha(self, tgt, enc_mem, mean_mem, enc_mask, mean_mask):
        causal = self._causal_mask(tgt)
        enc_kpm, mean_kpm = ~enc_mask, ~mean_mask
        hidden = self._embed(tgt)
        for layer in self.doha_dec_layers:
            hidden = layer(hidden, enc_mem, mean_mem, causal, enc_kpm, mean_kpm)
        return self.doha_dec_norm(hidden)
| # ========================================== | |
| # 3. LOAD MODEL | |
| # ========================================== | |
print("Loading model from HuggingFace...")
MODEL_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="best_model.pt")
TOKENIZER_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="tokenizer.model")

sp = spm.SentencePieceProcessor()
sp.load(TOKENIZER_PATH)
VOCAB_SIZE = sp.get_piece_size()

# Doha decoding stops at EOS or at the double danda that closes a doha.
# piece_to_id() returns the UNK id for a piece that is not in the vocabulary;
# in that case do not register UNK as a stop token (it would cut generation
# at any unknown token) and fall back to EOS only.
DANDAA_ID = sp.piece_to_id("॥")
if DANDAA_ID == sp.unk_id():
    print("Warning: '॥' is not a tokenizer piece; doha generation will stop on EOS only.")
    STOP_TOKENS = {cfg.EOS_ID}
else:
    STOP_TOKENS = {cfg.EOS_ID, DANDAA_ID}

model = UnifiedDohaModel(
    VOCAB_SIZE, cfg.D_MODEL, cfg.N_HEADS, cfg.N_ENC_LAYERS,
    cfg.N_MEANING_DEC_LAYERS, cfg.N_DOHA_DEC_LAYERS,
    cfg.D_FF, cfg.DROPOUT, cfg.MAX_SEQ_LEN, cfg.PAD_ID,
).to(DEVICE)

# SECURITY NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted repository. Consider weights_only=True if the checkpoint holds
# only tensors / plain containers (not verified from here).
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
# Presumably saved from a DataParallel/DDP wrapper, which prefixes keys with
# "module."; strip only that leading prefix, never an inner occurrence.
state_dict = {k.removeprefix("module."): v for k, v in ckpt["model_state"].items()}
model.load_state_dict(state_dict)
model.eval()
print(f"Model loaded. Vocab: {VOCAB_SIZE}")
| # ========================================== | |
| # 4. SAMPLING | |
| # ========================================== | |
def top_k_top_p_sample(logits, temperature=0.8, top_k=50, top_p=0.92, past_ids=None, rep_penalty=1.3):
    """Sample one token id from next-token logits.

    Steps, in order:

    1. Divide the logits by ``temperature``.
    2. Apply a repetition penalty to every id among the last 20 entries of
       ``past_ids``. Positive logits are divided by ``rep_penalty``; others are
       divided by ``1 / rep_penalty``. Either way they move downwards.
    3. Apply top-k filtering.
    4. Apply nucleus (top-p) filtering.

    Args:
        logits: next-token scores; a leading dim of size 1 is squeezed, so
            presumably shape (1, vocab).
        temperature: divisor for the logits (must be non-zero).
        top_k: keep logits >= the k-th largest; ``<= 0`` disables the filter.
        top_p: drop every token whose preceding cumulative (sorted) probability
            already exceeds ``top_p``; ``>= 1.0`` disables the filter.
        past_ids: previously generated ids used for the repetition penalty.
        rep_penalty: penalty factor; values ``<= 1.0`` disable the penalty.

    Returns:
        A 1-element LongTensor holding the sampled token id.
    """
    scores = logits.squeeze(0).float() / temperature

    if past_ids and rep_penalty > 1.0:
        for tok in set(past_ids[-20:]):
            if scores[tok] > 0:
                scores[tok] /= rep_penalty
            else:
                scores[tok] /= 1 / rep_penalty

    if top_k > 0:
        kth_best = torch.topk(scores, min(top_k, scores.size(-1))).values[-1]
        scores[scores < kth_best] = float("-inf")

    if top_p >= 1.0:
        return torch.multinomial(F.softmax(scores, dim=-1), 1)

    probs, order = torch.sort(F.softmax(scores, dim=-1), descending=True)
    mass_before = torch.cumsum(probs, dim=-1) - probs
    probs[mass_before > top_p] = 0.0
    return order.gather(-1, torch.multinomial(probs, 1))
def _as_batch(ids):
    """Wrap a list of token ids as a (1, len) LongTensor on DEVICE."""
    return torch.tensor([ids], dtype=torch.long, device=DEVICE)


def _encode_prompt(text):
    """Tokenize ``text`` for the encoder and return ``(ids, keep_mask)``.

    The ids are truncated to ``cfg.MAX_SEQ_LEN``: the positional-encoding table
    has exactly that many positions, so a longer (user-supplied) prompt would
    otherwise fail with a shape mismatch inside the model.
    """
    ids = _as_batch(sp.encode(text)[: cfg.MAX_SEQ_LEN])
    return ids, ids != cfg.PAD_ID


@torch.no_grad()
def generate(theme, context, temperature=0.8, top_k=50, top_p=0.92):
    """Generate a ``(meaning, doha)`` pair for a theme and context.

    Two-pass decoding:

    1. Encode the theme/context prompt and sample a meaning from the meaning
       decoder (until EOS or ``cfg.MAX_MEANING_LEN`` tokens).
    2. Re-encode the prompt with the generated meaning appended. Recompute the
       meaning-decoder states against that memory. Then sample the doha from
       the doha decoder until a stop token or ``cfg.MAX_DOHA_LEN`` tokens.

    Runs under ``torch.no_grad()``: this is pure inference, so no autograd
    graph should be built for each decoding step.

    Returns:
        ``(meaning_text, doha_text)`` decoded with the SentencePiece model.
    """
    enc_text = f"<theme> {theme} </theme> <context> {context} </context>"
    enc_ids, enc_mask = _encode_prompt(enc_text)
    enc_mem = model.encode(enc_ids, enc_mask)

    # ---- pass 1: meaning ----
    mean_ids = [cfg.BOS_ID]
    for _ in range(cfg.MAX_MEANING_LEN):
        m_out = model.decode_meaning(_as_batch(mean_ids), enc_mem, enc_mask)
        next_id = top_k_top_p_sample(
            model.meaning_proj(m_out[:, -1, :]),
            temperature=temperature, top_k=top_k, top_p=top_p,
            past_ids=mean_ids, rep_penalty=cfg.GEN_REP_PENALTY,
        ).item()
        if next_id == cfg.EOS_ID:
            break  # EOS is not kept in mean_ids
        mean_ids.append(next_id)
    gen_meaning = sp.decode(mean_ids[1:])

    # ---- pass 2: doha, conditioned on prompt + generated meaning ----
    enc_ids_full, enc_mask_full = _encode_prompt(f"{enc_text} <meaning> {gen_meaning} </meaning>")
    enc_mem_full = model.encode(enc_ids_full, enc_mask_full)
    mean_batch = _as_batch(mean_ids)
    m_mem = model.decode_meaning(mean_batch, enc_mem_full, enc_mask_full)
    m_mask = mean_batch != cfg.PAD_ID

    doha_ids = [cfg.BOS_ID]
    for _ in range(cfg.MAX_DOHA_LEN):
        d_out = model.decode_doha(_as_batch(doha_ids), enc_mem_full, m_mem, enc_mask_full, m_mask)
        next_id = top_k_top_p_sample(
            model.doha_proj(d_out[:, -1, :]),
            temperature=temperature, top_k=top_k, top_p=top_p,
            past_ids=doha_ids, rep_penalty=cfg.GEN_DOHA_REP_PEN,
        ).item()
        # Appended before the stop check, so a closing stop token is decoded too.
        doha_ids.append(next_id)
        if next_id in STOP_TOKENS:
            break
    return gen_meaning, sp.decode(doha_ids[1:])
| # ========================================== | |
| # 5. MATRA EVALUATION | |
| # ========================================== | |
# Devanagari combining signs recognised by the matra tokenizer.
HALANT = "\u094D"        # virama: removes the consonant's inherent vowel
ANUSVARA = "\u0902"
CHANDRABINDU = "\u0901"
VISARGA = "\u0903"
NUKTA = "\u093C"

# Independent vowels (swar) -> matra weight (1 = laghu/short, 2 = guru/long).
SWAR_WEIGHT = {
    "\u0905": 1,  # अ
    "\u0906": 2,  # आ
    "\u0907": 1,  # इ
    "\u0908": 2,  # ई
    "\u0909": 1,  # उ
    "\u090A": 2,  # ऊ
    "\u090B": 1,  # ऋ
    "\u090C": 1,  # ऌ
    "\u090F": 2,  # ए
    "\u0910": 2,  # ऐ
    "\u0913": 2,  # ओ
    "\u0914": 2,  # औ
}

# Dependent vowel signs attached to a consonant -> matra weight of that syllable.
MATRA_WEIGHT = {
    "\u093E": 2,  # ा
    "\u093F": 1,  # ि
    "\u0940": 2,  # ी
    "\u0941": 1,  # ु
    "\u0942": 2,  # ू
    "\u0943": 1,  # ृ
    "\u0947": 2,  # े
    "\u0948": 2,  # ै
    "\u094B": 2,  # ो
    "\u094C": 2,  # ौ
}
def is_consonant(ch):
    """Return True if the single character ``ch`` is a Devanagari consonant.

    Accepts क..ह (U+0915-U+0939) and the precomposed nukta letters
    क़..य़ (U+0958-U+095F).
    """
    code = ord(ch)
    if 0x0915 <= code <= 0x0939:
        return True
    return 0x0958 <= code <= 0x095F
def tokenize_dev(word):
    """Split Devanagari text into prosodic units with matra weights.

    Returns a list of ``{"unit": str, "weight": int}`` dicts in text order.
    A weight is 1 (laghu), 2 (guru) or 0 (a half consonant: consonant + halant).

    Rules:

    * An independent vowel takes its ``SWAR_WEIGHT``; a following anusvara or
      visarga makes it 2.
    * A consonant (optionally + nukta) followed by a vowel sign takes that
      sign's ``MATRA_WEIGHT``; otherwise it weighs 1 (inherent vowel). A
      following anusvara or visarga makes it 2; chandrabindu does not.
    * A consonant + halant weighs 0 and makes the preceding syllable guru.
    * Every other character (spaces, punctuation, Latin, ...) is skipped.

    The input is NFC-normalised first.

    NOTE(review): skipped characters do not reset the "preceding syllable", so
    for multi-word input a word-initial conjunct also promotes the previous
    word's last syllable. Confirm this matches the intended prosody rules.
    """
    word = unicodedata.normalize("NFC", word)
    tokens = []
    chars = list(word)
    i, n = 0, len(chars)
    while i < n:
        ch = chars[i]
        if ch in SWAR_WEIGHT:
            weight = SWAR_WEIGHT[ch]
            unit = ch
            i += 1
            while i < n and chars[i] in (ANUSVARA, VISARGA):
                weight = 2
                unit += chars[i]
                i += 1
            tokens.append({"unit": unit, "weight": weight})
        elif is_consonant(ch):
            unit = ch
            i += 1
            if i < n and chars[i] == NUKTA:
                unit += chars[i]
                i += 1
            if i < n and chars[i] == HALANT:
                unit += chars[i]
                i += 1
                # The syllable before a conjunct becomes guru. Only promote a
                # real syllable: if the previous unit is itself a half
                # consonant (weight 0, e.g. the ष् in राष्ट्र), it must stay 0.
                # Any syllable before that half consonant was already promoted
                # when the half consonant was read. Promoting the half consonant
                # used to add 2 spurious matras per extra consonant in a
                # conjunct (राष्ट्र counted 5 instead of 3).
                if tokens and tokens[-1]["weight"] > 0:
                    tokens[-1]["weight"] = 2
                tokens.append({"unit": unit, "weight": 0})
            else:
                mw = 1
                if i < n and chars[i] in MATRA_WEIGHT:
                    mw = MATRA_WEIGHT[chars[i]]
                    unit += chars[i]
                    i += 1
                while i < n and chars[i] in (ANUSVARA, VISARGA, CHANDRABINDU):
                    if chars[i] in (ANUSVARA, VISARGA):
                        mw = 2
                    unit += chars[i]
                    i += 1
                tokens.append({"unit": unit, "weight": mw})
        else:
            i += 1
    return tokens
def count_matra(text):
    """Return the total matra count of ``text`` (sum of its tokenize_dev unit weights)."""
    total = 0
    for unit in tokenize_dev(text):
        total += unit["weight"]
    return total
def parse_line(line):
    """Split one doha line into charans and return each charan's matra count.

    * A blank line yields ``[]``.
    * If the line contains a comma, every non-empty comma-separated segment is
      one charan.
    * Otherwise words are accumulated greedily, and a charan is closed as soon
      as its running count reaches 13. Any non-zero remainder forms a final
      charan. The threshold is 13 for every charan, so a shorter (e.g.
      11-matra) charan only closes when the line ends.
    """
    stripped = line.strip()
    if not stripped:
        return []

    if "," in stripped:
        segments = (seg.strip() for seg in stripped.split(","))
        return [count_matra(seg) for seg in segments if seg]

    counts = []
    running = 0
    for word in stripped.split():
        running += count_matra(word)
        if running >= 13:
            counts.append(running)
            running = 0
    if running > 0:
        counts.append(running)
    return counts
def get_charan_matras(doha_text):
    """Return the matra count of every charan in a doha, in order.

    Both danda forms (। and ॥) act as line separators. Each non-empty line is
    split into charans by ``parse_line``, and the per-line results are
    concatenated.
    """
    normalized = doha_text.replace("॥", "।").strip()
    counts = []
    for raw_line in normalized.split("।"):
        segment = raw_line.strip()
        if segment:
            counts += parse_line(segment)
    return counts
def compute_mas(cm):
    """Return the MAS score of a list of charan matra counts.

    MAS is the sum of absolute deviations of the first four counts from the
    ideal 13, 11, 13, 11. Missing charans count as 0 and extra charans are
    ignored, so 0 is a perfect score.
    """
    ideal = (13, 11, 13, 11)
    padded = (cm + [0, 0, 0, 0])[:4]
    return sum(abs(got - want) for got, want in zip(padded, ideal))
def evaluate_doha(doha_text):
    """Score the metre of a doha.

    Returns:
        ``(cm4, mas, total)``. ``cm4`` is the first four charan counts,
        zero-padded. ``mas`` is their MAS against 13/11/13/11. ``total`` is the
        sum of ``cm4``.
    """
    counts = get_charan_matras(doha_text)
    first_four = (counts + [0] * 4)[:4]
    return first_four, compute_mas(counts), sum(first_four)
| # ========================================== | |
| # 6. BEST-OF-N | |
| # ========================================== | |
def generate_best_of_n(theme, context, n, temperature, top_k, top_p):
    """Generate ``n`` candidates and return them sorted by MAS (best first).

    Each candidate is a dict with keys ``meaning``, ``doha``, ``cm4``, ``mas``
    and ``total``. The sort is stable, so equal-MAS candidates keep their
    generation order.
    """

    def _scored(meaning, doha):
        cm4, mas, total = evaluate_doha(doha)
        return {"meaning": meaning, "doha": doha, "cm4": cm4, "mas": mas, "total": total}

    attempts = (generate(theme, context, temperature, int(top_k), top_p) for _ in range(n))
    return sorted((_scored(meaning, doha) for meaning, doha in attempts), key=lambda cand: cand["mas"])
| # ========================================== | |
| # 7. MAIN HANDLER | |
| # ========================================== | |
def _md_cell(text):
    """Make ``text`` safe inside a Markdown table cell (escape pipes, flatten newlines)."""
    return text.replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def _format_metrics(best):
    """Render the per-charan matra table and MAS summary for the best candidate."""
    cm4, mas, total = best["cm4"], best["mas"], best["total"]
    if mas == 0:
        quality = "✅ Perfect!"
    elif mas <= 4:
        quality = "🟡 Good"
    else:
        quality = "🔴 High deviation"
    rows = "".join(
        f"| चरण {idx} | {ideal} | {got} | {abs(got - ideal)} |\n"
        for idx, (ideal, got) in enumerate(zip((13, 11, 13, 11), cm4), start=1)
    )
    return (
        "**📊 मात्रा विश्लेषण (Matra Analysis — Best Doha)**\n\n"
        "| चरण | आदर्श | प्राप्त | अंतर |\n"
        "|:---:|:-----:|:------:|:----:|\n"
        + rows
        + "\n"
        # A blank line (paragraph break) is needed: a single "\n" is a soft
        # break in Markdown and would render both summary lines on one line.
        + f"**कुल मात्राएँ:** {total} / 48\n\n"
        + f"**MAS:** {mas} {quality}"
    )


def _format_attempts_table(candidates):
    """Render every candidate (already sorted by MAS) as a Markdown table; the first is starred."""
    rows = []
    for rank, cand in enumerate(candidates):
        marker = "⭐" if rank == 0 else f"{rank + 1}."
        doha = cand["doha"]
        snippet = doha[:55] + "..." if len(doha) > 55 else doha
        # Escape after truncating so an escape sequence is never cut in half.
        rows.append(f"| {marker} | {_md_cell(snippet)} | {cand['total']}/48 | {cand['mas']} |")
    return (
        "**📋 सभी प्रयास — sorted by MAS (low = better)**\n\n"
        "| # | दोहा | मात्राएँ | MAS |\n"
        "|---|------|---------|-----|\n"
        + "\n".join(rows)
    )


def run_generation(theme, context, n_attempts, temperature, top_k, top_p):
    """Gradio click handler: generate best-of-N dohas and format the results.

    Returns:
        ``(doha, meaning, metrics_markdown, all_attempts_markdown)``. When
        theme or context is empty or missing, the first element is a warning
        and the rest are empty strings.
    """
    if not (theme or "").strip():
        return "⚠️ कृपया थीम दर्ज करें।", "", "", ""
    if not (context or "").strip():
        return "⚠️ कृपया संदर्भ दर्ज करें।", "", "", ""

    # Always make at least one attempt so there is a best candidate to show.
    n = max(1, int(n_attempts))
    candidates = generate_best_of_n(theme, context, n, temperature, int(top_k), top_p)
    best = candidates[0]
    return best["doha"], best["meaning"], _format_metrics(best), _format_attempts_table(candidates)
| # ========================================== | |
| # 8. GRADIO 6 UI | |
| # ========================================== | |
# (theme, context) pairs offered as one-click examples in the UI.
_EXAMPLE_PAIRS = (
    ("साहस", "अंधेरी सुरंग के अंत में प्रकाश की उम्मीद"),
    ("प्रेम", "सच्चा प्रेम"),
    ("विरह", "जुदाई का दुख"),
    ("गुरु", "गुरु से ज्ञान"),
    ("प्रकृति", "वसंत की सुंदरता"),
    ("कर्म", "कर्म का फल"),
    ("धैर्य", "नहीं रुकना"),
    ("आशा", "उम्मीद की किरण"),
    ("ज्ञान", "आत्मज्ञान जरूरी"),
    ("भक्ति", "दिल में भगवान"),
)
EXAMPLES = [list(pair) for pair in _EXAMPLE_PAIRS]
# Static Markdown shown above and below the interactive area.
_HEADER_MD = """
# 🪔 हिंदी दोहा जनरेटर
### Constrained Generation of Theme-Based Hindi Dohas
*Custom 14.7M parameter Transformer · Dual-Decoder Architecture · Two-Pass Inference*
---
"""

_ABOUT_MD = """
---
### 📖 About
| Component | Detail |
|-----------|--------|
| Architecture | Custom Encoder + Dual Decoder Transformer |
| Parameters | 14.7M |
| Stage 1 | Span-corruption pretraining on 58,500 Hindi kavitas |
| Stage 2 | Dual-decoder fine-tuning with learnable gated cross-attention |
| Inference | Two-pass: meaning first, then doha |
| Doha structure | 13+11 \\| 13+11 matras (48 total ideal) |
| MAS formula | \\|m1−13\\| + \\|m2−11\\| + \\|m3−13\\| + \\|m4−11\\| (0 = perfect) |
| Checkpoint | [nikpatidar333/doha-generation-model_v2](https://huggingface.co/nikpatidar333/doha-generation-model_v2) |
"""

# Two-column layout: inputs and sampling controls on the left, results on the
# right. The generate button calls run_generation with the six inputs and fills
# the four output components.
with gr.Blocks(title="हिंदी दोहा जनरेटर | Hindi Doha Generator") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # ---- left column: prompt + generation settings ----
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input")
            theme_input = gr.Textbox(
                label="थीम (Theme)",
                placeholder="e.g. साहस, प्रेम, विरह, गुरु...",
                value="साहस",
                lines=1,
            )
            context_input = gr.Textbox(
                label="संदर्भ (Context)",
                placeholder="e.g. अंधेरी सुरंग के अंत में प्रकाश की उम्मीद",
                value="अंधेरी सुरंग के अंत में प्रकाश की उम्मीद",
                lines=2,
            )

            gr.Markdown("### ⚙️ Generation Settings")
            n_attempts = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Best-of-N Attempts")
            temperature = gr.Slider(minimum=0.5, maximum=1.2, value=0.8, step=0.05, label="Temperature (creativity)")
            top_k = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Top-K")
            top_p = gr.Slider(minimum=0.5, maximum=1.0, value=0.92, step=0.01, label="Top-P")
            generate_btn = gr.Button("🎵 दोहा जनरेट करें", variant="primary", size="lg")

            gr.Markdown("### 💡 Examples — click to load")
            gr.Examples(examples=EXAMPLES, inputs=[theme_input, context_input])

        # ---- right column: generated text + metrics ----
        with gr.Column(scale=1):
            gr.Markdown("### 📜 Output")
            doha_output = gr.Textbox(label="🎵 Generated Doha (दोहा)", lines=3, interactive=False)
            meaning_output = gr.Textbox(label="💭 Generated Meaning (अर्थ)", lines=5, interactive=False)
            metrics_output = gr.Markdown()
            with gr.Accordion("📋 All Attempts Comparison", open=False):
                all_attempts_output = gr.Markdown()

    gr.Markdown(_ABOUT_MD)

    generate_btn.click(
        fn=run_generation,
        inputs=[theme_input, context_input, n_attempts, temperature, top_k, top_p],
        outputs=[doha_output, meaning_output, metrics_output, all_attempts_output],
    )

demo.launch()