Upload 3 files
- app.py +82 -0
- dictionary.pkl +3 -0
- sincode_model.py +268 -0
app.py
ADDED
@@ -0,0 +1,82 @@
import streamlit as st
import time
from sincode_model import BeamSearchDecoder
from PIL import Image
import base64

st.set_page_config(page_title="සිංCode Prototype", page_icon="🇱🇰", layout="centered")

def add_bg_from_local(image_file):
    # Injects the image as a base64-encoded CSS background, dimmed by a gradient.
    try:
        with open(image_file, "rb") as f:
            data = f.read()
        b64_data = base64.b64encode(data).decode()

        st.markdown(
            f"""
            <style>
            .stApp {{
                background-image: linear-gradient(rgba(0,0,0,0.7), rgba(0,0,0,0.7)), url(data:image/png;base64,{b64_data});
                background-size: cover;
                background-position: center;
                background-attachment: fixed;
            }}
            </style>
            """,
            unsafe_allow_html=True
        )
    except FileNotFoundError:
        pass

@st.cache_resource
def load_system():
    # Cached so the model and dictionary are loaded only once per session.
    decoder = BeamSearchDecoder()
    return decoder

# These asset paths assume a Colab session with Google Drive mounted.
background_path = "/content/drive/MyDrive/FYP/background.png"
add_bg_from_local(background_path)

with st.sidebar:
    logo = Image.open("/content/drive/MyDrive/FYP/SinCodeLogo.jpg")
    st.image(logo, width=200)
    st.title("සිංCode Project")
    st.info("Prototype")
    st.markdown("### 🏗 Architecture")
    st.success("""
    **Hybrid Neuro-Symbolic Engine**
    Combines rule-based speed with deep-learning (XLM-R) context awareness.

    **Adaptive Code-Switching**
    Detects and preserves English words in mixed input.

    **Contextual Disambiguation**
    Resolves Singlish ambiguity using sentence-level probability.
    """)

    st.markdown("---")
    st.write("© 2026 Kalana Chandrasekara")

st.title("සිංCode: Context-Aware Transliteration")
st.markdown("Type Singlish sentences below. The system handles **code-mixing**, **ambiguity**, and **punctuation**.")

input_text = st.text_area("Input Text", height=100, placeholder="e.g., Singlish sentences type krnna")

if st.button("Transliterate", type="primary", use_container_width=True) and input_text:
    try:
        with st.spinner("Processing..."):
            decoder = load_system()
            start_time = time.time()
            result, trace_logs = decoder.decode(input_text)
            end_time = time.time()

        st.success("Transliteration Complete")
        st.markdown(f"### {result}")
        st.caption(f"Time: {round(end_time - start_time, 2)}s")

        with st.expander("See How It Works (Debug Info)", expanded=True):
            st.write("The log below shows the candidate scoring for each word step:")
            for log in trace_logs:
                st.markdown(log)
                st.divider()

    except Exception as e:
        st.error(f"Error: {e}")
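The app is launched with "streamlit run app.py". The @st.cache_resource wrapper keeps a single BeamSearchDecoder (and therefore a single XLM-R instance) alive across Streamlit reruns. Note that the two /content/drive/... asset paths assume a Colab session with Google Drive mounted: a missing background image is silently skipped, but a missing logo will raise inside the sidebar.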
dictionary.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f7444b74f2fcf8f208e47f087ee778f11086eab74f54e4f3e07fb6cc06c88ea8
size 326599345
sincode_model.py
ADDED
@@ -0,0 +1,268 @@
import torch
import math
import re
import os
import pickle
import requests
from transformers import AutoTokenizer, AutoModelForMaskedLM

# --- 0. SETUP ROBUST ENGLISH VOCAB ---
def load_english_corpus():
    # 1. Define core "safety" words
    core_english = {
        "transliteration", "sincode", "prototype", "assignment", "singlish",
        "rest", "complete", "tutorial", "small", "mistakes", "game", "play",
        "type", "test", "online", "code", "mixing", "project", "demo", "today",
        "tomorrow", "presentation", "slide"
    }

    url = "https://raw.githubusercontent.com/first20hours/google-10000-english/master/20k.txt"
    file_path = "english_20k.txt"

    download_success = False

    # 2. Try to load/download the 20k corpus
    if not os.path.exists(file_path):
        try:
            print("🌐 Downloading English Corpus...")
            r = requests.get(url, timeout=5)
            with open(file_path, "wb") as f:
                f.write(r.content)
            download_success = True
        except Exception:
            print("Internet Warning: Could not download English corpus. Using fallback list.")
    else:
        download_success = True

    # 3. Combine lists
    full_vocab = core_english.copy()

    if download_success and os.path.exists(file_path):
        try:
            with open(file_path, "r") as f:
                downloaded_words = set(f.read().splitlines())
            full_vocab.update(downloaded_words)
        except Exception:
            pass

    print(f"English Vocab Loaded: {len(full_vocab)} words")
    return full_vocab

ENGLISH_VOCAB = load_english_corpus()

# --- 1. RULE BASED ENGINE ---
# (Standard rule tables. The backslash-escaped entries, e.g. "\\u005C" + "y"
# and "o\\)", are regex escapes carried over from the JavaScript version of
# these rules, so they only fire if the input literally contains a backslash
# or parenthesis sequence.)
nVowels = 26
consonants = ["nnd", "nndh", "nng", "th", "dh", "gh", "ch", "ph", "bh", "jh", "sh", "GN", "KN", "Lu", "kh", "Th", "Dh", "S", "d", "c", "th", "t", "k", "D", "n", "p", "b", "m", "\\u005C" + "y", "Y", "y", "j", "l", "v", "w", "s", "h", "N", "L", "K", "G", "P", "B", "f", "g", "r"]
consonantsUni = ["ඬ", "ඳ", "ඟ", "ත", "ධ", "ඝ", "ච", "ඵ", "භ", "ඣ", "ෂ", "ඥ", "ඤ", "ළු", "ඛ", "ඨ", "ඪ", "ශ", "ද", "ච", "ත", "ට", "ක", "ඩ", "න", "ප", "බ", "ම", "ය", "ය", "ය", "ජ", "ල", "ව", "ව", "ස", "හ", "ණ", "ළ", "ඛ", "ඝ", "ඵ", "ඹ", "ෆ", "ග", "ර"]
vowels = ["oo", "o\\)", "oe", "aa", "a\\)", "Aa", "A\\)", "ae", "ii", "i\\)", "ie", "ee", "ea", "e\\)", "ei", "uu", "u\\)", "au", "\\a", "a", "A", "i", "e", "u", "o", "I"]
vowelsUni = ["ඌ", "ඕ", "ඕ", "ආ", "ආ", "ඈ", "ඈ", "ඈ", "ඊ", "ඊ", "ඊ", "ඊ", "ඒ", "ඒ", "ඒ", "ඌ", "ඌ", "ඖ", "ඇ", "අ", "ඇ", "ඉ", "එ", "උ", "ඔ", "ඓ"]
vowelModifiersUni = ["ූ", "ෝ", "ෝ", "ා", "ා", "ෑ", "ෑ", "ෑ", "ී", "ී", "ී", "ී", "ේ", "ේ", "ේ", "ූ", "ූ", "ෞ", "ැ", "", "ැ", "ි", "ෙ", "ු", "ො", "ෛ"]
specialConsonants = ["\\n", "\\h", "\\N", "\\R", "R", "\\r"]
specialConsonantsUni = ["ං", "ඃ", "ඞ", "ඍ", "ර්" + "\u200D", "ර්" + "\u200D"]
specialChar = ["ruu", "ru"]
specialCharUni = ["ෲ", "ෘ"]

def rule_based_transliterate(text):
    # (The original `r = s.replace(s + "/G", "")` lines were no-ops inherited
    # from the JS regex port — s never contains s + "/G" — so the plain
    # replacements below are equivalent.)
    # Special consonants (anusvara, visarga, etc.) first
    for i in range(len(specialConsonants)):
        text = text.replace(specialConsonants[i], specialConsonantsUni[i])
    # Consonant + "ruu"/"ru" combinations
    for i in range(len(specialCharUni)):
        for j in range(len(consonants)):
            s = consonants[j] + specialChar[i]
            v = consonantsUni[j] + specialCharUni[i]
            text = text.replace(s, v)
    # Consonant + "r" clusters, with and without a trailing vowel
    for j in range(len(consonants)):
        for i in range(len(vowels)):
            s = consonants[j] + "r" + vowels[i]
            v = consonantsUni[j] + "්ර" + vowelModifiersUni[i]
            text = text.replace(s, v)
        s = consonants[j] + "r"
        v = consonantsUni[j] + "්ර"
        text = text.replace(s, v)
    # Consonant + vowel combinations
    for i in range(len(consonants)):
        for j in range(nVowels):
            s = consonants[i] + vowels[j]
            v = consonantsUni[i] + vowelModifiersUni[j]
            text = text.replace(s, v)
    # Bare consonants take the hal kirima (්)
    for i in range(len(consonants)):
        text = text.replace(consonants[i], consonantsUni[i] + "්")
    # Standalone vowels
    for i in range(len(vowels)):
        text = text.replace(vowels[i], vowelsUni[i])
    return text

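# For intuition: the rule engine is a pure string-rewriting pass with no
# dictionary or model involved, so it can be exercised on its own.
# Illustrative example (traced through the tables above):
#
#     rule_based_transliterate("mama")  # -> "මම" ("ma" -> ම, applied twice)
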
# --- 2. DICTIONARY ADAPTER ---
class DictionaryAdapter:
    def __init__(self, dictionary_dict):
        self.dictionary = dictionary_dict

    def get_candidates(self, word):
        cands = []
        word_lower = word.lower()

        # 1. English corpus check
        if word_lower in ENGLISH_VOCAB:
            cands.append(word)

        # 2. Sinhala dictionary check
        if word in self.dictionary:
            cands.extend(self.dictionary[word])
        elif word_lower in self.dictionary:
            cands.extend(self.dictionary[word_lower])

        # 3. Clean & return (dict.fromkeys de-duplicates while preserving order)
        if cands:
            return list(dict.fromkeys(cands))

        # 4. Fallback: subword splits (only if NO candidates were found)
        length = len(word)
        if length > 3:
            for i in range(2, length - 1):
                part1 = word[:i]
                part2 = word[i:]
                p1_cands = self.dictionary.get(part1) or self.dictionary.get(part1.lower())
                p2_cands = self.dictionary.get(part2) or self.dictionary.get(part2.lower())

                if p1_cands and p2_cands:
                    # Combine the top 3 candidates of each half
                    for w1 in p1_cands[:3]:
                        for w2 in p2_cands[:3]:
                            cands.append(w1 + w2)

        if cands:
            return list(set(cands))
        return []

    def get_rule_output(self, word):
        return rule_based_transliterate(word)

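# Illustrative behaviour of the adapter above: an English word such as
# "project" survives as a candidate via ENGLISH_VOCAB, a known Singlish key
# returns its stored Sinhala candidates, and an unknown word longer than
# 3 characters falls back to concatenations of two-way splits (top 3
# candidates per half).
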
# --- 3. BEAM SEARCH DECODER (With Enhanced Trace) ---
class BeamSearchDecoder:
    def __init__(self, model_name="FacebookAI/xlm-roberta-base", dictionary_path="dictionary.pkl", device=None):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

        with open(dictionary_path, "rb") as f:
            d_data = pickle.load(f)
        self.adapter = DictionaryAdapter(d_data)

    def batch_score(self, contexts, candidates):
        # One forward pass scores every (context, candidate) pair in the batch.
        inputs = self.tokenizer(contexts, return_tensors="pt", padding=True, truncation=True).to(self.device)
        mask_token_id = self.tokenizer.mask_token_id
        scores = []
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits

        for i, target in enumerate(candidates):
            token_ids = inputs.input_ids[i]
            mask_indices = (token_ids == mask_token_id).nonzero(as_tuple=True)
            if len(mask_indices[0]) == 0:
                scores.append(-100.0)
                continue

            mask_pos = mask_indices[0].item()
            probs = torch.softmax(logits[i, mask_pos, :], dim=0)
            target_ids = self.tokenizer.encode(target, add_special_tokens=False)

            if not target_ids:
                scores.append(-100.0)
                continue

            # All sub-tokens are read off the same mask position (an
            # approximation for multi-token candidates), then length-normalized.
            word_score = sum(math.log(probs[tid].item() + 1e-9) for tid in target_ids)
            scores.append(word_score / len(target_ids))
        return scores

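# Scoring note: batch_score assigns a candidate w with sub-tokens t_1 ... t_n
# the length-normalized pseudo-log-likelihood
#
#     score(w) = (1/n) * sum_i log p(t_i | context with a single <mask>)
#
# which keeps decoding to one forward pass per beam step and avoids
# penalizing candidates merely for tokenizing into more pieces.
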
    def decode(self, sentence, beam_width=3):
        words = sentence.split()
        candidate_sets, penalties, future_context = [], [], []
        punct_pattern = re.compile(r"^(\W*)(.*?)(\W*)$")
        trace_logs = []

        # Pass 1: generate candidates per word and assign code-switch penalties.
        for raw in words:
            match = punct_pattern.match(raw)
            prefix, core, suffix = match.groups() if match else ("", raw, "")

            if not core:
                candidate_sets.append([raw])
                penalties.append([0.0])
                future_context.append(raw)
                continue

            # 1. Get candidates
            cands = self.adapter.get_candidates(core)
            rule_cand = self.adapter.get_rule_output(core)

            if not cands:
                cands = [rule_cand]
                curr_penalties = [0.0]
            else:
                curr_penalties = []
                has_english = any(c.lower() in ENGLISH_VOCAB for c in cands)

                for c in cands:
                    is_eng = c.lower() in ENGLISH_VOCAB
                    is_rule_match = (c == rule_cand)

                    if is_eng:
                        curr_penalties.append(0.0)
                    elif has_english:
                        # Penalize Sinhala readings when an English reading exists
                        curr_penalties.append(5.0)
                    elif is_rule_match:
                        curr_penalties.append(0.0)
                    else:
                        curr_penalties.append(2.0)

            final_cands = [prefix + c + suffix for c in cands]
            candidate_sets.append(final_cands[:6])
            penalties.append(curr_penalties[:6])
            best_idx = curr_penalties.index(min(curr_penalties))
            future_context.append(final_cands[best_idx])

        # Pass 2: beam search, scoring each candidate in its masked context.
        beam = [([], 0.0)]
        for t in range(len(words)):
            candidates = candidate_sets[t]
            curr_penalties = penalties[t]
            next_beam = []

            batch_ctx, batch_tgt, batch_meta = [], [], []

            for p_idx, (p_path, p_score) in enumerate(beam):
                for c_idx, cand in enumerate(candidates):
                    future = future_context[t + 1:] if t + 1 < len(words) else []
                    ctx = " ".join(p_path + [self.tokenizer.mask_token] + future)
                    batch_ctx.append(ctx)
                    batch_tgt.append(cand)
                    batch_meta.append((p_idx, c_idx))

            if batch_ctx:
                scores = self.batch_score(batch_ctx, batch_tgt)
                # --- TRACE LOGGING ---
                step_log = f"**Step {t+1}: {words[t]}**\n"
                for i, score in enumerate(scores):
                    p_idx, c_idx = batch_meta[i]
                    orig_path, orig_score = beam[p_idx]
                    final_score = score - curr_penalties[c_idx]
                    next_beam.append((orig_path + [batch_tgt[i]], orig_score + final_score))

                    # Add to the log only if the score is reasonable (reduces noise)
                    if score > -25.0:
                        word = batch_tgt[i]
                        penalty = curr_penalties[c_idx]
                        step_log += f"- `{word}` (Pen: {penalty}) -> **{final_score:.2f}**\n"
                trace_logs.append(step_log)

            if not next_beam:
                continue
            beam = sorted(next_beam, key=lambda x: x[1], reverse=True)[:beam_width]

        final_output = " ".join(beam[0][0]) if beam else ""
        return final_output, trace_logs
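Taken together, the decoder can also be driven outside Streamlit. A minimal end-to-end sketch (assuming dictionary.pkl has been fetched via Git LFS into the working directory; the sample sentence is illustrative):

    from sincode_model import BeamSearchDecoder

    decoder = BeamSearchDecoder()  # loads XLM-R and dictionary.pkl
    result, trace_logs = decoder.decode("mama gedara yanawa")
    print(result)             # best-scoring beam, rejoined into a sentence
    for log in trace_logs:    # the per-step candidate scores shown in the app
        print(log)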