Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,13 +9,24 @@ from torch.nn.utils.rnn import pad_sequence
|
|
| 9 |
import firebase_admin
|
| 10 |
from firebase_admin import credentials, firestore
|
| 11 |
|
|
|
|
| 12 |
# Define the model architecture
|
| 13 |
class CTCTransliterator(nn.Module):
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
super().__init__()
|
| 16 |
self.embed = nn.Embedding(input_dim, hidden_dim, padding_idx=0)
|
| 17 |
-
self.lstm = nn.LSTM(hidden_dim,
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
self.layer_norm = nn.LayerNorm(hidden_dim * 2)
|
| 20 |
self.dropout = nn.Dropout(dropout)
|
| 21 |
self.upsample_factor = upsample_factor
|
|
@@ -30,20 +41,26 @@ class CTCTransliterator(nn.Module):
|
|
| 30 |
|
| 31 |
# (seq_len, batch, hidden) → (batch, hidden, seq_len)
|
| 32 |
x = x.permute(1, 2, 0)
|
| 33 |
-
x = F.interpolate(x,
|
|
|
|
|
|
|
|
|
|
| 34 |
# → (batch, hidden, seq_len*upsample_factor)
|
| 35 |
-
x = x.permute(2, 0,
|
|
|
|
| 36 |
|
| 37 |
x = self.fc(x)
|
| 38 |
x = x.log_softmax(dim=2)
|
| 39 |
return x
|
| 40 |
|
|
|
|
| 41 |
# Firebase Cache System
|
| 42 |
class FirebaseCache:
|
|
|
|
| 43 |
def __init__(self):
|
| 44 |
self.db = None
|
| 45 |
self.init_firebase()
|
| 46 |
-
|
| 47 |
def init_firebase(self):
|
| 48 |
"""Initialize Firebase connection"""
|
| 49 |
try:
|
|
@@ -53,118 +70,127 @@ class FirebaseCache:
|
|
| 53 |
if os.getenv('FIREBASE_CREDENTIALS'):
|
| 54 |
# Parse credentials from environment variable
|
| 55 |
import base64
|
| 56 |
-
cred_data = json.loads(
|
|
|
|
|
|
|
| 57 |
cred = credentials.Certificate(cred_data)
|
| 58 |
elif os.path.exists('firebase-credentials.json'):
|
| 59 |
# For local development
|
| 60 |
cred = credentials.Certificate('firebase-credentials.json')
|
| 61 |
else:
|
| 62 |
-
print(
|
|
|
|
|
|
|
| 63 |
return
|
| 64 |
-
|
| 65 |
firebase_admin.initialize_app(cred)
|
| 66 |
self.db = firestore.client()
|
| 67 |
print("Firebase initialized successfully!")
|
| 68 |
else:
|
| 69 |
self.db = firestore.client()
|
| 70 |
-
|
| 71 |
except Exception as e:
|
| 72 |
print(f"Firebase initialization failed: {e}")
|
| 73 |
print("Falling back to local cache mode")
|
| 74 |
self.db = None
|
| 75 |
-
|
| 76 |
-
def _create_cache_key(self, input_text):
|
| 77 |
"""Create a safe document key for Firestore"""
|
| 78 |
import hashlib
|
| 79 |
# Create hash to handle special characters and length limits
|
| 80 |
-
key = f"{input_text}"
|
| 81 |
return hashlib.md5(key.encode()).hexdigest()
|
| 82 |
-
|
| 83 |
-
def get(self, input_text):
|
| 84 |
"""Get cached translation from Firebase"""
|
| 85 |
if not self.db:
|
| 86 |
return None
|
| 87 |
-
|
| 88 |
try:
|
| 89 |
-
doc_key = self._create_cache_key(input_text)
|
| 90 |
doc = self.db.collection('translations').document(doc_key).get()
|
| 91 |
-
|
| 92 |
if doc.exists:
|
| 93 |
data = doc.to_dict()
|
| 94 |
# Update usage count
|
| 95 |
self.db.collection('translations').document(doc_key).update({
|
| 96 |
-
'usage_count':
|
| 97 |
-
'
|
|
|
|
|
|
|
| 98 |
})
|
| 99 |
print(f"Cache hit: {input_text}")
|
| 100 |
return data.get('output', '')
|
| 101 |
-
|
| 102 |
return None
|
| 103 |
-
|
| 104 |
except Exception as e:
|
| 105 |
print(f"Cache read error: {e}")
|
| 106 |
return None
|
| 107 |
-
|
| 108 |
-
def set(self, input_text, output):
|
| 109 |
"""Store translation in Firebase"""
|
| 110 |
if not self.db:
|
| 111 |
return False
|
| 112 |
-
|
| 113 |
try:
|
| 114 |
-
doc_key = self._create_cache_key(input_text)
|
| 115 |
doc_data = {
|
| 116 |
'input': input_text,
|
|
|
|
| 117 |
'output': output,
|
| 118 |
'corrected_output': '',
|
| 119 |
'timestamp': datetime.now(),
|
| 120 |
'last_used': datetime.now(),
|
| 121 |
'usage_count': 1
|
| 122 |
}
|
| 123 |
-
|
| 124 |
self.db.collection('translations').document(doc_key).set(doc_data)
|
| 125 |
print(f"Cached: {input_text} → {output}")
|
| 126 |
return True
|
| 127 |
-
|
| 128 |
except Exception as e:
|
| 129 |
print(f"Cache write error: {e}")
|
| 130 |
return False
|
| 131 |
-
|
| 132 |
-
def update_correction(self, input_text, corrected_output):
|
| 133 |
"""Update translation with user correction"""
|
| 134 |
if not self.db:
|
| 135 |
return False
|
| 136 |
-
|
| 137 |
try:
|
| 138 |
-
doc_key = self._create_cache_key(input_text)
|
| 139 |
self.db.collection('translations').document(doc_key).update({
|
| 140 |
-
'corrected_output':
|
| 141 |
-
|
|
|
|
|
|
|
| 142 |
})
|
| 143 |
print(f"Correction saved: {input_text} → {corrected_output}")
|
| 144 |
return True
|
| 145 |
-
|
| 146 |
except Exception as e:
|
| 147 |
print(f"Correction save error: {e}")
|
| 148 |
return False
|
| 149 |
-
|
| 150 |
def get_stats(self):
|
| 151 |
"""Get cache statistics"""
|
| 152 |
if not self.db:
|
| 153 |
return "Firebase not connected"
|
| 154 |
-
|
| 155 |
try:
|
| 156 |
docs = self.db.collection('translations').get()
|
| 157 |
total = len(docs)
|
| 158 |
-
|
| 159 |
corrected = 0
|
| 160 |
total_usage = 0
|
| 161 |
-
|
| 162 |
for doc in docs:
|
| 163 |
data = doc.to_dict()
|
| 164 |
if data.get('corrected_output'):
|
| 165 |
corrected += 1
|
| 166 |
total_usage += data.get('usage_count', 0)
|
| 167 |
-
|
| 168 |
return f"""
|
| 169 |
Cache Statistics:
|
| 170 |
• Total translations: {total}
|
|
@@ -172,62 +198,67 @@ Cache Statistics:
|
|
| 172 |
• Total usage count: {total_usage}
|
| 173 |
• Average usage: {total_usage/total if total > 0 else 0:.1f} per translation
|
| 174 |
""".strip()
|
| 175 |
-
|
| 176 |
except Exception as e:
|
| 177 |
return f"Error getting stats: {e}"
|
| 178 |
|
|
|
|
| 179 |
# Load vocabularies and model
|
| 180 |
def load_model_and_vocabs():
|
| 181 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 182 |
-
|
| 183 |
# Load vocabularies
|
| 184 |
with open('latin_stoi.json', 'r', encoding='utf-8') as f:
|
| 185 |
latin_stoi = json.load(f)
|
| 186 |
with open('latin_itos.json', 'r', encoding='utf-8') as f:
|
| 187 |
latin_itos = json.load(f)
|
| 188 |
-
|
| 189 |
with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
|
| 190 |
arabic_stoi = json.load(f)
|
| 191 |
with open('arabic_itos.json', 'r', encoding='utf-8') as f:
|
| 192 |
-
arabic_itos= json.load(f)
|
| 193 |
-
|
| 194 |
# Initialize model
|
| 195 |
-
model = CTCTransliterator(
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
).to(device)
|
| 203 |
-
|
| 204 |
# Load trained weights
|
| 205 |
-
model.load_state_dict(
|
|
|
|
| 206 |
model.eval()
|
| 207 |
-
|
| 208 |
-
blank_id = arabic_stoi.get('<blank>', len(arabic_itos)-1)
|
| 209 |
return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
|
| 210 |
|
|
|
|
| 211 |
# Load everything at startup
|
| 212 |
-
model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs(
|
|
|
|
| 213 |
firebase_cache = FirebaseCache()
|
| 214 |
|
|
|
|
| 215 |
def encode_text(text, vocab):
|
| 216 |
"""Encode text using vocabulary"""
|
| 217 |
-
return torch.tensor([vocab.get(ch, 0) for ch in text.strip()],
|
|
|
|
|
|
|
| 218 |
|
| 219 |
def greedy_decode(log_probs, blank_id, itos, stoi):
|
| 220 |
"""
|
| 221 |
Decode CTC outputs using greedy decoding.
|
| 222 |
"""
|
| 223 |
-
eos_id = stoi.get('<eos>', len(stoi)-2)
|
| 224 |
preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
|
| 225 |
results = []
|
| 226 |
raw_results = []
|
| 227 |
print(eos_id, blank_id)
|
| 228 |
print(stoi)
|
| 229 |
print(type(blank_id))
|
| 230 |
-
print(stoi.get('<eos>',0))
|
| 231 |
for i, pred in enumerate(preds):
|
| 232 |
prev = None
|
| 233 |
decoded = []
|
|
@@ -239,7 +270,7 @@ def greedy_decode(log_probs, blank_id, itos, stoi):
|
|
| 239 |
break
|
| 240 |
# CTC collapse: skip blanks and repeated characters
|
| 241 |
if p != blank_id and p != prev:
|
| 242 |
-
decoded.append(itos[str(p)])
|
| 243 |
prev = p
|
| 244 |
raw_result.append(itos[str(p)])
|
| 245 |
|
|
@@ -249,110 +280,116 @@ def greedy_decode(log_probs, blank_id, itos, stoi):
|
|
| 249 |
|
| 250 |
return results
|
| 251 |
|
|
|
|
| 252 |
def transliterate_latin_to_arabic(text):
|
| 253 |
"""Transliterate Latin script to Arabic script with Firebase caching"""
|
| 254 |
if not text.strip():
|
| 255 |
return ""
|
| 256 |
-
|
| 257 |
# Check Firebase cache first
|
| 258 |
cached_result = firebase_cache.get(text, "Latin → Arabic")
|
| 259 |
if cached_result:
|
| 260 |
return cached_result
|
| 261 |
-
|
| 262 |
try:
|
| 263 |
# Encode input text
|
| 264 |
src = encode_text(text, latin_stoi).unsqueeze(1).to(device)
|
| 265 |
-
|
| 266 |
# Generate prediction
|
| 267 |
with torch.no_grad():
|
| 268 |
out = model(src)
|
| 269 |
-
|
| 270 |
# Decode output
|
| 271 |
decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
|
| 272 |
result = decoded[0] if decoded else ""
|
| 273 |
-
|
| 274 |
# Cache the result in Firebase
|
| 275 |
firebase_cache.set(text, "Latin → Arabic", result)
|
| 276 |
-
|
| 277 |
return result
|
| 278 |
-
|
| 279 |
except Exception as e:
|
| 280 |
return f"Error: {str(e)}"
|
| 281 |
|
|
|
|
| 282 |
def transliterate_arabic_to_latin(text):
|
| 283 |
"""Transliterate Arabic script to Latin script (placeholder)"""
|
| 284 |
return "Arabic to Latin transliteration not implemented yet."
|
| 285 |
|
| 286 |
-
|
|
|
|
| 287 |
"""Main transliteration function"""
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
|
| 291 |
-
def save_correction(input_text, corrected_output):
|
| 292 |
"""Save user correction to Firebase"""
|
| 293 |
-
if firebase_cache.update_correction(input_text,
|
|
|
|
| 294 |
return "Correction saved to the database! Thank you for improving the model."
|
| 295 |
else:
|
| 296 |
return "Could not save correction to databse."
|
| 297 |
|
|
|
|
| 298 |
# Arabic keyboard layout
|
| 299 |
-
arabic_keys = [
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
]
|
| 305 |
|
| 306 |
# Create Gradio interface
|
| 307 |
def create_interface():
|
| 308 |
-
with gr.Blocks(title="Darija Transliterator",
|
| 309 |
-
|
| 310 |
-
|
| 311 |
# Darija Transliterator
|
| 312 |
Convert between Latin script and Arabic script for Moroccan Darija
|
| 313 |
|
| 314 |
**Firebase-Powered**: Persistent caching across sessions
|
| 315 |
**Arabic Keyboard**: Built-in Arabic keyboard for corrections
|
| 316 |
**Real-time Stats**: Live usage analytics
|
| 317 |
-
"""
|
| 318 |
-
|
| 319 |
-
|
| 320 |
# Stats section
|
| 321 |
with gr.Row():
|
| 322 |
stats_btn = gr.Button("Show Statistics", variant="secondary")
|
| 323 |
-
stats_display = gr.Textbox(
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
with gr.Row():
|
| 331 |
with gr.Column(scale=1):
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
input_text = gr.Textbox(
|
| 334 |
placeholder="Enter text to transliterate...",
|
| 335 |
label="Input Text",
|
| 336 |
lines=4,
|
| 337 |
-
max_lines=10
|
| 338 |
-
|
| 339 |
-
|
| 340 |
with gr.Row():
|
| 341 |
clear_btn = gr.Button("Clear", variant="secondary")
|
| 342 |
-
translate_btn = gr.Button("Transliterate",
|
| 343 |
-
|
|
|
|
| 344 |
with gr.Column(scale=1):
|
| 345 |
-
output_text = gr.Textbox(
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
)
|
| 351 |
-
|
| 352 |
# Arabic Keyboard
|
| 353 |
gr.Markdown("### Arabic Keyboard")
|
| 354 |
gr.Markdown("*Click letters to edit the output text above*")
|
| 355 |
-
|
| 356 |
with gr.Group():
|
| 357 |
for row in arabic_keys:
|
| 358 |
with gr.Row():
|
|
@@ -360,114 +397,95 @@ def create_interface():
|
|
| 360 |
btn = gr.Button(char, size="sm", scale=1)
|
| 361 |
btn.click(
|
| 362 |
fn=None,
|
| 363 |
-
js=
|
|
|
|
| 364 |
inputs=[output_text],
|
| 365 |
outputs=[output_text],
|
| 366 |
show_progress=False,
|
| 367 |
-
queue=False
|
| 368 |
-
|
| 369 |
-
|
| 370 |
with gr.Row():
|
| 371 |
space_btn = gr.Button("Space", size="sm", scale=2)
|
| 372 |
-
backspace_btn = gr.Button("⌫ Backspace",
|
| 373 |
-
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
# Correction system
|
| 376 |
with gr.Group():
|
| 377 |
gr.Markdown("### Correction System")
|
| 378 |
-
correction_status = gr.Textbox(
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
# Keyboard utility buttons
|
| 386 |
-
space_btn.click(
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
js="() => ''",
|
| 407 |
-
outputs=[output_text],
|
| 408 |
-
show_progress=False,
|
| 409 |
-
queue=False
|
| 410 |
-
)
|
| 411 |
-
|
| 412 |
# Stats button
|
| 413 |
-
stats_btn.click(
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
outputs=[stats_display]
|
| 419 |
-
)
|
| 420 |
-
|
| 421 |
# Example inputs
|
| 422 |
gr.Markdown("### Examples")
|
| 423 |
-
examples = [
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
# Event handlers
|
| 439 |
-
translate_btn.click(
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
)
|
| 458 |
-
|
| 459 |
-
save_correction_btn.click(
|
| 460 |
-
fn=save_correction,
|
| 461 |
-
inputs=[input_text, output_text],
|
| 462 |
-
outputs=[correction_status]
|
| 463 |
-
).then(
|
| 464 |
-
fn=lambda: gr.update(visible=True),
|
| 465 |
-
outputs=[correction_status]
|
| 466 |
-
)
|
| 467 |
-
|
| 468 |
# Information
|
| 469 |
-
gr.Markdown(
|
| 470 |
-
"""
|
| 471 |
### About
|
| 472 |
This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
|
| 473 |
|
|
@@ -482,12 +500,12 @@ def create_interface():
|
|
| 482 |
1. Use the Arabic keyboard to correct any wrong translations
|
| 483 |
2. Click "Save Correction" to store your improvement
|
| 484 |
3. Your corrections help train better models for everyone!
|
| 485 |
-
"""
|
| 486 |
-
|
| 487 |
-
|
| 488 |
return demo
|
| 489 |
|
|
|
|
| 490 |
# Launch the app
|
| 491 |
if __name__ == "__main__":
|
| 492 |
demo = create_interface()
|
| 493 |
-
demo.launch(share=True)
|
|
|
|
| 9 |
import firebase_admin
|
| 10 |
from firebase_admin import credentials, firestore
|
| 11 |
|
| 12 |
+
|
| 13 |
# Define the model architecture
|
| 14 |
class CTCTransliterator(nn.Module):
|
| 15 |
+
|
| 16 |
+
def __init__(self,
|
| 17 |
+
input_dim,
|
| 18 |
+
hidden_dim,
|
| 19 |
+
output_dim,
|
| 20 |
+
num_layers=3,
|
| 21 |
+
dropout=0.3,
|
| 22 |
+
upsample_factor=3):
|
| 23 |
super().__init__()
|
| 24 |
self.embed = nn.Embedding(input_dim, hidden_dim, padding_idx=0)
|
| 25 |
+
self.lstm = nn.LSTM(hidden_dim,
|
| 26 |
+
hidden_dim,
|
| 27 |
+
num_layers=num_layers,
|
| 28 |
+
bidirectional=True,
|
| 29 |
+
dropout=dropout)
|
| 30 |
self.layer_norm = nn.LayerNorm(hidden_dim * 2)
|
| 31 |
self.dropout = nn.Dropout(dropout)
|
| 32 |
self.upsample_factor = upsample_factor
|
|
|
|
| 41 |
|
| 42 |
# (seq_len, batch, hidden) → (batch, hidden, seq_len)
|
| 43 |
x = x.permute(1, 2, 0)
|
| 44 |
+
x = F.interpolate(x,
|
| 45 |
+
scale_factor=self.upsample_factor,
|
| 46 |
+
mode='linear',
|
| 47 |
+
align_corners=False)
|
| 48 |
# → (batch, hidden, seq_len*upsample_factor)
|
| 49 |
+
x = x.permute(2, 0,
|
| 50 |
+
1) # back to (seq_len*upsample_factor, batch, hidden)
|
| 51 |
|
| 52 |
x = self.fc(x)
|
| 53 |
x = x.log_softmax(dim=2)
|
| 54 |
return x
|
| 55 |
|
| 56 |
+
|
| 57 |
# Firebase Cache System
|
| 58 |
class FirebaseCache:
|
| 59 |
+
|
| 60 |
def __init__(self):
|
| 61 |
self.db = None
|
| 62 |
self.init_firebase()
|
| 63 |
+
|
| 64 |
def init_firebase(self):
|
| 65 |
"""Initialize Firebase connection"""
|
| 66 |
try:
|
|
|
|
| 70 |
if os.getenv('FIREBASE_CREDENTIALS'):
|
| 71 |
# Parse credentials from environment variable
|
| 72 |
import base64
|
| 73 |
+
cred_data = json.loads(
|
| 74 |
+
base64.b64decode(
|
| 75 |
+
os.getenv('FIREBASE_CREDENTIALS')).decode())
|
| 76 |
cred = credentials.Certificate(cred_data)
|
| 77 |
elif os.path.exists('firebase-credentials.json'):
|
| 78 |
# For local development
|
| 79 |
cred = credentials.Certificate('firebase-credentials.json')
|
| 80 |
else:
|
| 81 |
+
print(
|
| 82 |
+
"No Firebase credentials found. Using local cache fallback."
|
| 83 |
+
)
|
| 84 |
return
|
| 85 |
+
|
| 86 |
firebase_admin.initialize_app(cred)
|
| 87 |
self.db = firestore.client()
|
| 88 |
print("Firebase initialized successfully!")
|
| 89 |
else:
|
| 90 |
self.db = firestore.client()
|
| 91 |
+
|
| 92 |
except Exception as e:
|
| 93 |
print(f"Firebase initialization failed: {e}")
|
| 94 |
print("Falling back to local cache mode")
|
| 95 |
self.db = None
|
| 96 |
+
|
| 97 |
+
def _create_cache_key(self, input_text, direction):
|
| 98 |
"""Create a safe document key for Firestore"""
|
| 99 |
import hashlib
|
| 100 |
# Create hash to handle special characters and length limits
|
| 101 |
+
key = f"{input_text}_{direction}"
|
| 102 |
return hashlib.md5(key.encode()).hexdigest()
|
| 103 |
+
|
| 104 |
+
def get(self, input_text, direction):
|
| 105 |
"""Get cached translation from Firebase"""
|
| 106 |
if not self.db:
|
| 107 |
return None
|
| 108 |
+
|
| 109 |
try:
|
| 110 |
+
doc_key = self._create_cache_key(input_text, direction)
|
| 111 |
doc = self.db.collection('translations').document(doc_key).get()
|
| 112 |
+
|
| 113 |
if doc.exists:
|
| 114 |
data = doc.to_dict()
|
| 115 |
# Update usage count
|
| 116 |
self.db.collection('translations').document(doc_key).update({
|
| 117 |
+
'usage_count':
|
| 118 |
+
data.get('usage_count', 0) + 1,
|
| 119 |
+
'last_used':
|
| 120 |
+
datetime.now()
|
| 121 |
})
|
| 122 |
print(f"Cache hit: {input_text}")
|
| 123 |
return data.get('output', '')
|
| 124 |
+
|
| 125 |
return None
|
| 126 |
+
|
| 127 |
except Exception as e:
|
| 128 |
print(f"Cache read error: {e}")
|
| 129 |
return None
|
| 130 |
+
|
| 131 |
+
def set(self, input_text, direction, output):
|
| 132 |
"""Store translation in Firebase"""
|
| 133 |
if not self.db:
|
| 134 |
return False
|
| 135 |
+
|
| 136 |
try:
|
| 137 |
+
doc_key = self._create_cache_key(input_text, direction)
|
| 138 |
doc_data = {
|
| 139 |
'input': input_text,
|
| 140 |
+
'direction': direction,
|
| 141 |
'output': output,
|
| 142 |
'corrected_output': '',
|
| 143 |
'timestamp': datetime.now(),
|
| 144 |
'last_used': datetime.now(),
|
| 145 |
'usage_count': 1
|
| 146 |
}
|
| 147 |
+
|
| 148 |
self.db.collection('translations').document(doc_key).set(doc_data)
|
| 149 |
print(f"Cached: {input_text} → {output}")
|
| 150 |
return True
|
| 151 |
+
|
| 152 |
except Exception as e:
|
| 153 |
print(f"Cache write error: {e}")
|
| 154 |
return False
|
| 155 |
+
|
| 156 |
+
def update_correction(self, input_text, direction, corrected_output):
|
| 157 |
"""Update translation with user correction"""
|
| 158 |
if not self.db:
|
| 159 |
return False
|
| 160 |
+
|
| 161 |
try:
|
| 162 |
+
doc_key = self._create_cache_key(input_text, direction)
|
| 163 |
self.db.collection('translations').document(doc_key).update({
|
| 164 |
+
'corrected_output':
|
| 165 |
+
corrected_output,
|
| 166 |
+
'correction_timestamp':
|
| 167 |
+
datetime.now()
|
| 168 |
})
|
| 169 |
print(f"Correction saved: {input_text} → {corrected_output}")
|
| 170 |
return True
|
| 171 |
+
|
| 172 |
except Exception as e:
|
| 173 |
print(f"Correction save error: {e}")
|
| 174 |
return False
|
| 175 |
+
|
| 176 |
def get_stats(self):
|
| 177 |
"""Get cache statistics"""
|
| 178 |
if not self.db:
|
| 179 |
return "Firebase not connected"
|
| 180 |
+
|
| 181 |
try:
|
| 182 |
docs = self.db.collection('translations').get()
|
| 183 |
total = len(docs)
|
| 184 |
+
|
| 185 |
corrected = 0
|
| 186 |
total_usage = 0
|
| 187 |
+
|
| 188 |
for doc in docs:
|
| 189 |
data = doc.to_dict()
|
| 190 |
if data.get('corrected_output'):
|
| 191 |
corrected += 1
|
| 192 |
total_usage += data.get('usage_count', 0)
|
| 193 |
+
|
| 194 |
return f"""
|
| 195 |
Cache Statistics:
|
| 196 |
• Total translations: {total}
|
|
|
|
| 198 |
• Total usage count: {total_usage}
|
| 199 |
• Average usage: {total_usage/total if total > 0 else 0:.1f} per translation
|
| 200 |
""".strip()
|
| 201 |
+
|
| 202 |
except Exception as e:
|
| 203 |
return f"Error getting stats: {e}"
|
| 204 |
|
| 205 |
+
|
| 206 |
# Load vocabularies and model
|
| 207 |
def load_model_and_vocabs():
|
| 208 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 209 |
+
|
| 210 |
# Load vocabularies
|
| 211 |
with open('latin_stoi.json', 'r', encoding='utf-8') as f:
|
| 212 |
latin_stoi = json.load(f)
|
| 213 |
with open('latin_itos.json', 'r', encoding='utf-8') as f:
|
| 214 |
latin_itos = json.load(f)
|
| 215 |
+
|
| 216 |
with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
|
| 217 |
arabic_stoi = json.load(f)
|
| 218 |
with open('arabic_itos.json', 'r', encoding='utf-8') as f:
|
| 219 |
+
arabic_itos = json.load(f)
|
| 220 |
+
|
| 221 |
# Initialize model
|
| 222 |
+
model = CTCTransliterator(len(latin_stoi),
|
| 223 |
+
256,
|
| 224 |
+
len(arabic_stoi),
|
| 225 |
+
num_layers=3,
|
| 226 |
+
dropout=0.3,
|
| 227 |
+
upsample_factor=2).to(device)
|
| 228 |
+
|
|
|
|
|
|
|
| 229 |
# Load trained weights
|
| 230 |
+
model.load_state_dict(
|
| 231 |
+
torch.load('best_model.pth', map_location=device, weights_only=False))
|
| 232 |
model.eval()
|
| 233 |
+
|
| 234 |
+
blank_id = arabic_stoi.get('<blank>', len(arabic_itos) - 1)
|
| 235 |
return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
|
| 236 |
|
| 237 |
+
|
| 238 |
# Load everything at startup
|
| 239 |
+
model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs(
|
| 240 |
+
)
|
| 241 |
firebase_cache = FirebaseCache()
|
| 242 |
|
| 243 |
+
|
| 244 |
def encode_text(text, vocab):
|
| 245 |
"""Encode text using vocabulary"""
|
| 246 |
+
return torch.tensor([vocab.get(ch, 0) for ch in text.strip()],
|
| 247 |
+
dtype=torch.long)
|
| 248 |
+
|
| 249 |
|
| 250 |
def greedy_decode(log_probs, blank_id, itos, stoi):
|
| 251 |
"""
|
| 252 |
Decode CTC outputs using greedy decoding.
|
| 253 |
"""
|
| 254 |
+
eos_id = stoi.get('<eos>', len(stoi) - 2)
|
| 255 |
preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
|
| 256 |
results = []
|
| 257 |
raw_results = []
|
| 258 |
print(eos_id, blank_id)
|
| 259 |
print(stoi)
|
| 260 |
print(type(blank_id))
|
| 261 |
+
print(stoi.get('<eos>', 0))
|
| 262 |
for i, pred in enumerate(preds):
|
| 263 |
prev = None
|
| 264 |
decoded = []
|
|
|
|
| 270 |
break
|
| 271 |
# CTC collapse: skip blanks and repeated characters
|
| 272 |
if p != blank_id and p != prev:
|
| 273 |
+
decoded.append(itos[str(p)])
|
| 274 |
prev = p
|
| 275 |
raw_result.append(itos[str(p)])
|
| 276 |
|
|
|
|
| 280 |
|
| 281 |
return results
|
| 282 |
|
| 283 |
+
|
| 284 |
def transliterate_latin_to_arabic(text):
|
| 285 |
"""Transliterate Latin script to Arabic script with Firebase caching"""
|
| 286 |
if not text.strip():
|
| 287 |
return ""
|
| 288 |
+
|
| 289 |
# Check Firebase cache first
|
| 290 |
cached_result = firebase_cache.get(text, "Latin → Arabic")
|
| 291 |
if cached_result:
|
| 292 |
return cached_result
|
| 293 |
+
|
| 294 |
try:
|
| 295 |
# Encode input text
|
| 296 |
src = encode_text(text, latin_stoi).unsqueeze(1).to(device)
|
| 297 |
+
|
| 298 |
# Generate prediction
|
| 299 |
with torch.no_grad():
|
| 300 |
out = model(src)
|
| 301 |
+
|
| 302 |
# Decode output
|
| 303 |
decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
|
| 304 |
result = decoded[0] if decoded else ""
|
| 305 |
+
|
| 306 |
# Cache the result in Firebase
|
| 307 |
firebase_cache.set(text, "Latin → Arabic", result)
|
| 308 |
+
|
| 309 |
return result
|
| 310 |
+
|
| 311 |
except Exception as e:
|
| 312 |
return f"Error: {str(e)}"
|
| 313 |
|
| 314 |
+
|
| 315 |
def transliterate_arabic_to_latin(text):
|
| 316 |
"""Transliterate Arabic script to Latin script (placeholder)"""
|
| 317 |
return "Arabic to Latin transliteration not implemented yet."
|
| 318 |
|
| 319 |
+
|
| 320 |
+
def transliterate(text, direction):
|
| 321 |
"""Main transliteration function"""
|
| 322 |
+
if direction == "Latin → Arabic":
|
| 323 |
+
return transliterate_latin_to_arabic(text.lower())
|
| 324 |
+
else:
|
| 325 |
+
return transliterate_arabic_to_latin(text)
|
| 326 |
|
| 327 |
|
| 328 |
+
def save_correction(input_text, direction, corrected_output):
|
| 329 |
"""Save user correction to Firebase"""
|
| 330 |
+
if firebase_cache.update_correction(input_text, direction,
|
| 331 |
+
corrected_output):
|
| 332 |
return "Correction saved to the database! Thank you for improving the model."
|
| 333 |
else:
|
| 334 |
return "Could not save correction to databse."
|
| 335 |
|
| 336 |
+
|
| 337 |
# Arabic keyboard layout
|
| 338 |
+
arabic_keys = [['ض', 'ص', 'ث', 'ق', 'ف', 'غ', 'ع', 'ه', 'خ', 'ح', 'ج', 'د'],
|
| 339 |
+
['ش', 'س', 'ي', 'ب', 'ل', 'ا', 'ت', 'ن', 'م', 'ك', 'ط'],
|
| 340 |
+
['ئ', 'ء', 'ؤ', 'ر', 'لا', 'ى', 'ة', 'و', 'ز', 'ظ'],
|
| 341 |
+
['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']]
|
| 342 |
+
|
|
|
|
| 343 |
|
| 344 |
# Create Gradio interface
|
| 345 |
def create_interface():
|
| 346 |
+
with gr.Blocks(title="Darija Transliterator",
|
| 347 |
+
theme=gr.themes.Soft()) as demo:
|
| 348 |
+
gr.Markdown("""
|
| 349 |
# Darija Transliterator
|
| 350 |
Convert between Latin script and Arabic script for Moroccan Darija
|
| 351 |
|
| 352 |
**Firebase-Powered**: Persistent caching across sessions
|
| 353 |
**Arabic Keyboard**: Built-in Arabic keyboard for corrections
|
| 354 |
**Real-time Stats**: Live usage analytics
|
| 355 |
+
""")
|
| 356 |
+
|
|
|
|
| 357 |
# Stats section
|
| 358 |
with gr.Row():
|
| 359 |
stats_btn = gr.Button("Show Statistics", variant="secondary")
|
| 360 |
+
stats_display = gr.Textbox(label="Firebase Statistics",
|
| 361 |
+
interactive=False,
|
| 362 |
+
visible=False,
|
| 363 |
+
lines=5)
|
| 364 |
+
|
|
|
|
|
|
|
| 365 |
with gr.Row():
|
| 366 |
with gr.Column(scale=1):
|
| 367 |
+
direction = gr.Radio(
|
| 368 |
+
choices=["Latin → Arabic"],
|
| 369 |
+
value="Latin → Arabic",
|
| 370 |
+
label="Translation Direction")
|
| 371 |
+
|
| 372 |
input_text = gr.Textbox(
|
| 373 |
placeholder="Enter text to transliterate...",
|
| 374 |
label="Input Text",
|
| 375 |
lines=4,
|
| 376 |
+
max_lines=10)
|
| 377 |
+
|
|
|
|
| 378 |
with gr.Row():
|
| 379 |
clear_btn = gr.Button("Clear", variant="secondary")
|
| 380 |
+
translate_btn = gr.Button("Transliterate",
|
| 381 |
+
variant="primary")
|
| 382 |
+
|
| 383 |
with gr.Column(scale=1):
|
| 384 |
+
output_text = gr.Textbox(label="Output",
|
| 385 |
+
lines=4,
|
| 386 |
+
max_lines=10,
|
| 387 |
+
interactive=True)
|
| 388 |
+
|
|
|
|
|
|
|
| 389 |
# Arabic Keyboard
|
| 390 |
gr.Markdown("### Arabic Keyboard")
|
| 391 |
gr.Markdown("*Click letters to edit the output text above*")
|
| 392 |
+
|
| 393 |
with gr.Group():
|
| 394 |
for row in arabic_keys:
|
| 395 |
with gr.Row():
|
|
|
|
| 397 |
btn = gr.Button(char, size="sm", scale=1)
|
| 398 |
btn.click(
|
| 399 |
fn=None,
|
| 400 |
+
js=
|
| 401 |
+
f"(output_text) => output_text + '{char}'",
|
| 402 |
inputs=[output_text],
|
| 403 |
outputs=[output_text],
|
| 404 |
show_progress=False,
|
| 405 |
+
queue=False)
|
| 406 |
+
|
|
|
|
| 407 |
with gr.Row():
|
| 408 |
space_btn = gr.Button("Space", size="sm", scale=2)
|
| 409 |
+
backspace_btn = gr.Button("⌫ Backspace",
|
| 410 |
+
size="sm",
|
| 411 |
+
scale=2)
|
| 412 |
+
clear_output_btn = gr.Button("Clear Output",
|
| 413 |
+
size="sm",
|
| 414 |
+
scale=2)
|
| 415 |
+
|
| 416 |
# Correction system
|
| 417 |
with gr.Group():
|
| 418 |
gr.Markdown("### Correction System")
|
| 419 |
+
correction_status = gr.Textbox(label="Status",
|
| 420 |
+
interactive=False,
|
| 421 |
+
visible=False)
|
| 422 |
+
save_correction_btn = gr.Button("Save Correction",
|
| 423 |
+
variant="secondary")
|
| 424 |
+
|
|
|
|
| 425 |
# Keyboard utility buttons
|
| 426 |
+
space_btn.click(fn=None,
|
| 427 |
+
js="(output_text) => output_text + ' '",
|
| 428 |
+
inputs=[output_text],
|
| 429 |
+
outputs=[output_text],
|
| 430 |
+
show_progress=False,
|
| 431 |
+
queue=False)
|
| 432 |
+
|
| 433 |
+
backspace_btn.click(fn=None,
|
| 434 |
+
js="(output_text) => output_text.slice(0, -1)",
|
| 435 |
+
inputs=[output_text],
|
| 436 |
+
outputs=[output_text],
|
| 437 |
+
show_progress=False,
|
| 438 |
+
queue=False)
|
| 439 |
+
|
| 440 |
+
clear_output_btn.click(fn=None,
|
| 441 |
+
js="() => ''",
|
| 442 |
+
outputs=[output_text],
|
| 443 |
+
show_progress=False,
|
| 444 |
+
queue=False)
|
| 445 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
# Stats button
|
| 447 |
+
stats_btn.click(fn=firebase_cache.get_stats,
|
| 448 |
+
outputs=[stats_display
|
| 449 |
+
]).then(fn=lambda: gr.update(visible=True),
|
| 450 |
+
outputs=[stats_display])
|
| 451 |
+
|
|
|
|
|
|
|
|
|
|
| 452 |
# Example inputs
|
| 453 |
gr.Markdown("### Examples")
|
| 454 |
+
examples = [["makay3nich bli katkhdam bzaf", "Latin → Arabic"],
|
| 455 |
+
[
|
| 456 |
+
"rah bayn dkchi li katdir kolchi 3ay9 bik",
|
| 457 |
+
"Latin → Arabic"
|
| 458 |
+
],
|
| 459 |
+
["wach na9dar nakhod caipirinha, 3afak", "Latin → Arabic"],
|
| 460 |
+
["ghadi temchi f lkhedma mzyan", "Latin → Arabic"]]
|
| 461 |
+
|
| 462 |
+
gr.Examples(examples=examples,
|
| 463 |
+
inputs=[input_text, direction],
|
| 464 |
+
outputs=output_text,
|
| 465 |
+
fn=transliterate,
|
| 466 |
+
cache_examples=False)
|
| 467 |
+
|
|
|
|
| 468 |
# Event handlers
|
| 469 |
+
translate_btn.click(fn=transliterate,
|
| 470 |
+
inputs=[input_text, direction],
|
| 471 |
+
outputs=output_text).then(
|
| 472 |
+
fn=lambda: gr.update(visible=True),
|
| 473 |
+
outputs=[correction_status])
|
| 474 |
+
|
| 475 |
+
clear_btn.click(fn=lambda: ("", ""), outputs=[input_text, output_text])
|
| 476 |
+
|
| 477 |
+
input_text.submit(fn=transliterate,
|
| 478 |
+
inputs=[input_text, direction],
|
| 479 |
+
outputs=output_text)
|
| 480 |
+
|
| 481 |
+
save_correction_btn.click(fn=save_correction,
|
| 482 |
+
inputs=[input_text, direction, output_text],
|
| 483 |
+
outputs=[correction_status]).then(
|
| 484 |
+
fn=lambda: gr.update(visible=True),
|
| 485 |
+
outputs=[correction_status])
|
| 486 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
# Information
|
| 488 |
+
gr.Markdown("""
|
|
|
|
| 489 |
### About
|
| 490 |
This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
|
| 491 |
|
|
|
|
| 500 |
1. Use the Arabic keyboard to correct any wrong translations
|
| 501 |
2. Click "Save Correction" to store your improvement
|
| 502 |
3. Your corrections help train better models for everyone!
|
| 503 |
+
""")
|
| 504 |
+
|
|
|
|
| 505 |
return demo
|
| 506 |
|
| 507 |
+
|
| 508 |
# Launch the app
|
| 509 |
if __name__ == "__main__":
|
| 510 |
demo = create_interface()
|
| 511 |
+
demo.launch(share=True)
|