Haitam03 committed on
Commit
cdec623
·
verified ·
1 Parent(s): c086db7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -204
app.py CHANGED
@@ -9,13 +9,24 @@ from torch.nn.utils.rnn import pad_sequence
9
  import firebase_admin
10
  from firebase_admin import credentials, firestore
11
 
 
12
  # Define the model architecture
13
  class CTCTransliterator(nn.Module):
14
- def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.3, upsample_factor=3):
 
 
 
 
 
 
 
15
  super().__init__()
16
  self.embed = nn.Embedding(input_dim, hidden_dim, padding_idx=0)
17
- self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers,
18
- bidirectional=True, dropout=dropout)
 
 
 
19
  self.layer_norm = nn.LayerNorm(hidden_dim * 2)
20
  self.dropout = nn.Dropout(dropout)
21
  self.upsample_factor = upsample_factor
@@ -30,20 +41,26 @@ class CTCTransliterator(nn.Module):
30
 
31
  # (seq_len, batch, hidden) → (batch, hidden, seq_len)
32
  x = x.permute(1, 2, 0)
33
- x = F.interpolate(x, scale_factor=self.upsample_factor, mode='linear', align_corners=False)
 
 
 
34
  # → (batch, hidden, seq_len*upsample_factor)
35
- x = x.permute(2, 0, 1) # back to (seq_len*upsample_factor, batch, hidden)
 
36
 
37
  x = self.fc(x)
38
  x = x.log_softmax(dim=2)
39
  return x
40
 
 
41
  # Firebase Cache System
42
  class FirebaseCache:
 
43
  def __init__(self):
44
  self.db = None
45
  self.init_firebase()
46
-
47
  def init_firebase(self):
48
  """Initialize Firebase connection"""
49
  try:
@@ -53,118 +70,127 @@ class FirebaseCache:
53
  if os.getenv('FIREBASE_CREDENTIALS'):
54
  # Parse credentials from environment variable
55
  import base64
56
- cred_data = json.loads(base64.b64decode(os.getenv('FIREBASE_CREDENTIALS')).decode())
 
 
57
  cred = credentials.Certificate(cred_data)
58
  elif os.path.exists('firebase-credentials.json'):
59
  # For local development
60
  cred = credentials.Certificate('firebase-credentials.json')
61
  else:
62
- print("No Firebase credentials found. Using local cache fallback.")
 
 
63
  return
64
-
65
  firebase_admin.initialize_app(cred)
66
  self.db = firestore.client()
67
  print("Firebase initialized successfully!")
68
  else:
69
  self.db = firestore.client()
70
-
71
  except Exception as e:
72
  print(f"Firebase initialization failed: {e}")
73
  print("Falling back to local cache mode")
74
  self.db = None
75
-
76
- def _create_cache_key(self, input_text):
77
  """Create a safe document key for Firestore"""
78
  import hashlib
79
  # Create hash to handle special characters and length limits
80
- key = f"{input_text}"
81
  return hashlib.md5(key.encode()).hexdigest()
82
-
83
- def get(self, input_text):
84
  """Get cached translation from Firebase"""
85
  if not self.db:
86
  return None
87
-
88
  try:
89
- doc_key = self._create_cache_key(input_text)
90
  doc = self.db.collection('translations').document(doc_key).get()
91
-
92
  if doc.exists:
93
  data = doc.to_dict()
94
  # Update usage count
95
  self.db.collection('translations').document(doc_key).update({
96
- 'usage_count': data.get('usage_count', 0) + 1,
97
- 'last_used': datetime.now()
 
 
98
  })
99
  print(f"Cache hit: {input_text}")
100
  return data.get('output', '')
101
-
102
  return None
103
-
104
  except Exception as e:
105
  print(f"Cache read error: {e}")
106
  return None
107
-
108
- def set(self, input_text, output):
109
  """Store translation in Firebase"""
110
  if not self.db:
111
  return False
112
-
113
  try:
114
- doc_key = self._create_cache_key(input_text)
115
  doc_data = {
116
  'input': input_text,
 
117
  'output': output,
118
  'corrected_output': '',
119
  'timestamp': datetime.now(),
120
  'last_used': datetime.now(),
121
  'usage_count': 1
122
  }
123
-
124
  self.db.collection('translations').document(doc_key).set(doc_data)
125
  print(f"Cached: {input_text} → {output}")
126
  return True
127
-
128
  except Exception as e:
129
  print(f"Cache write error: {e}")
130
  return False
131
-
132
- def update_correction(self, input_text, corrected_output):
133
  """Update translation with user correction"""
134
  if not self.db:
135
  return False
136
-
137
  try:
138
- doc_key = self._create_cache_key(input_text)
139
  self.db.collection('translations').document(doc_key).update({
140
- 'corrected_output': corrected_output,
141
- 'correction_timestamp': datetime.now()
 
 
142
  })
143
  print(f"Correction saved: {input_text} → {corrected_output}")
144
  return True
145
-
146
  except Exception as e:
147
  print(f"Correction save error: {e}")
148
  return False
149
-
150
  def get_stats(self):
151
  """Get cache statistics"""
152
  if not self.db:
153
  return "Firebase not connected"
154
-
155
  try:
156
  docs = self.db.collection('translations').get()
157
  total = len(docs)
158
-
159
  corrected = 0
160
  total_usage = 0
161
-
162
  for doc in docs:
163
  data = doc.to_dict()
164
  if data.get('corrected_output'):
165
  corrected += 1
166
  total_usage += data.get('usage_count', 0)
167
-
168
  return f"""
169
  Cache Statistics:
170
  • Total translations: {total}
@@ -172,62 +198,67 @@ Cache Statistics:
172
  • Total usage count: {total_usage}
173
  • Average usage: {total_usage/total if total > 0 else 0:.1f} per translation
174
  """.strip()
175
-
176
  except Exception as e:
177
  return f"Error getting stats: {e}"
178
 
 
179
  # Load vocabularies and model
180
  def load_model_and_vocabs():
181
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
182
-
183
  # Load vocabularies
184
  with open('latin_stoi.json', 'r', encoding='utf-8') as f:
185
  latin_stoi = json.load(f)
186
  with open('latin_itos.json', 'r', encoding='utf-8') as f:
187
  latin_itos = json.load(f)
188
-
189
  with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
190
  arabic_stoi = json.load(f)
191
  with open('arabic_itos.json', 'r', encoding='utf-8') as f:
192
- arabic_itos= json.load(f)
193
-
194
  # Initialize model
195
- model = CTCTransliterator(
196
- len(latin_stoi),
197
- 256,
198
- len(arabic_stoi),
199
- num_layers=3,
200
- dropout=0.3,
201
- upsample_factor=2
202
- ).to(device)
203
-
204
  # Load trained weights
205
- model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=False))
 
206
  model.eval()
207
-
208
- blank_id = arabic_stoi.get('<blank>', len(arabic_itos)-1)
209
  return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
210
 
 
211
  # Load everything at startup
212
- model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs()
 
213
  firebase_cache = FirebaseCache()
214
 
 
215
  def encode_text(text, vocab):
216
  """Encode text using vocabulary"""
217
- return torch.tensor([vocab.get(ch, 0) for ch in text.strip()], dtype=torch.long)
 
 
218
 
219
  def greedy_decode(log_probs, blank_id, itos, stoi):
220
  """
221
  Decode CTC outputs using greedy decoding.
222
  """
223
- eos_id = stoi.get('<eos>', len(stoi)-2)
224
  preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
225
  results = []
226
  raw_results = []
227
  print(eos_id, blank_id)
228
  print(stoi)
229
  print(type(blank_id))
230
- print(stoi.get('<eos>',0))
231
  for i, pred in enumerate(preds):
232
  prev = None
233
  decoded = []
@@ -239,7 +270,7 @@ def greedy_decode(log_probs, blank_id, itos, stoi):
239
  break
240
  # CTC collapse: skip blanks and repeated characters
241
  if p != blank_id and p != prev:
242
- decoded.append(itos[str(p)])
243
  prev = p
244
  raw_result.append(itos[str(p)])
245
 
@@ -249,110 +280,116 @@ def greedy_decode(log_probs, blank_id, itos, stoi):
249
 
250
  return results
251
 
 
252
  def transliterate_latin_to_arabic(text):
253
  """Transliterate Latin script to Arabic script with Firebase caching"""
254
  if not text.strip():
255
  return ""
256
-
257
  # Check Firebase cache first
258
  cached_result = firebase_cache.get(text, "Latin → Arabic")
259
  if cached_result:
260
  return cached_result
261
-
262
  try:
263
  # Encode input text
264
  src = encode_text(text, latin_stoi).unsqueeze(1).to(device)
265
-
266
  # Generate prediction
267
  with torch.no_grad():
268
  out = model(src)
269
-
270
  # Decode output
271
  decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
272
  result = decoded[0] if decoded else ""
273
-
274
  # Cache the result in Firebase
275
  firebase_cache.set(text, "Latin → Arabic", result)
276
-
277
  return result
278
-
279
  except Exception as e:
280
  return f"Error: {str(e)}"
281
 
 
282
  def transliterate_arabic_to_latin(text):
283
  """Transliterate Arabic script to Latin script (placeholder)"""
284
  return "Arabic to Latin transliteration not implemented yet."
285
 
286
- def transliterate(text):
 
287
  """Main transliteration function"""
288
- return transliterate_latin_to_arabic(text.lower())
 
 
 
289
 
290
 
291
- def save_correction(input_text, corrected_output):
292
  """Save user correction to Firebase"""
293
- if firebase_cache.update_correction(input_text, corrected_output):
 
294
  return "Correction saved to the database! Thank you for improving the model."
295
  else:
296
  return "Could not save correction to databse."
297
 
 
298
  # Arabic keyboard layout
299
- arabic_keys = [
300
- ['ض', 'ص', 'ث', 'ق', 'ف', 'غ', 'ع', 'ه', 'خ', 'ح', 'ج', 'د'],
301
- ['ش', 'س', 'ي', 'ب', 'ل', 'ا', 'ت', 'ن', 'م', 'ك', 'ط'],
302
- ['ئ', 'ء', 'ؤ', 'ر', 'لا', 'ى', 'ة', 'و', 'ز', 'ظ'],
303
- ['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠']
304
- ]
305
 
306
  # Create Gradio interface
307
  def create_interface():
308
- with gr.Blocks(title="Darija Transliterator", theme=gr.themes.Soft()) as demo:
309
- gr.Markdown(
310
- """
311
  # Darija Transliterator
312
  Convert between Latin script and Arabic script for Moroccan Darija
313
 
314
  **Firebase-Powered**: Persistent caching across sessions
315
  **Arabic Keyboard**: Built-in Arabic keyboard for corrections
316
  **Real-time Stats**: Live usage analytics
317
- """
318
- )
319
-
320
  # Stats section
321
  with gr.Row():
322
  stats_btn = gr.Button("Show Statistics", variant="secondary")
323
- stats_display = gr.Textbox(
324
- label="Firebase Statistics",
325
- interactive=False,
326
- visible=False,
327
- lines=5
328
- )
329
-
330
  with gr.Row():
331
  with gr.Column(scale=1):
332
-
 
 
 
 
333
  input_text = gr.Textbox(
334
  placeholder="Enter text to transliterate...",
335
  label="Input Text",
336
  lines=4,
337
- max_lines=10
338
- )
339
-
340
  with gr.Row():
341
  clear_btn = gr.Button("Clear", variant="secondary")
342
- translate_btn = gr.Button("Transliterate", variant="primary")
343
-
 
344
  with gr.Column(scale=1):
345
- output_text = gr.Textbox(
346
- label="Output",
347
- lines=4,
348
- max_lines=10,
349
- interactive=True
350
- )
351
-
352
  # Arabic Keyboard
353
  gr.Markdown("### Arabic Keyboard")
354
  gr.Markdown("*Click letters to edit the output text above*")
355
-
356
  with gr.Group():
357
  for row in arabic_keys:
358
  with gr.Row():
@@ -360,114 +397,95 @@ def create_interface():
360
  btn = gr.Button(char, size="sm", scale=1)
361
  btn.click(
362
  fn=None,
363
- js=f"(output_text) => output_text + '{char}'",
 
364
  inputs=[output_text],
365
  outputs=[output_text],
366
  show_progress=False,
367
- queue=False
368
- )
369
-
370
  with gr.Row():
371
  space_btn = gr.Button("Space", size="sm", scale=2)
372
- backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
373
- clear_output_btn = gr.Button("Clear Output", size="sm", scale=2)
374
-
 
 
 
 
375
  # Correction system
376
  with gr.Group():
377
  gr.Markdown("### Correction System")
378
- correction_status = gr.Textbox(
379
- label="Status",
380
- interactive=False,
381
- visible=False
382
- )
383
- save_correction_btn = gr.Button("Save Correction", variant="secondary")
384
-
385
  # Keyboard utility buttons
386
- space_btn.click(
387
- fn=None,
388
- js="(output_text) => output_text + ' '",
389
- inputs=[output_text],
390
- outputs=[output_text],
391
- show_progress=False,
392
- queue=False
393
- )
394
-
395
- backspace_btn.click(
396
- fn=None,
397
- js="(output_text) => output_text.slice(0, -1)",
398
- inputs=[output_text],
399
- outputs=[output_text],
400
- show_progress=False,
401
- queue=False
402
- )
403
-
404
- clear_output_btn.click(
405
- fn=None,
406
- js="() => ''",
407
- outputs=[output_text],
408
- show_progress=False,
409
- queue=False
410
- )
411
-
412
  # Stats button
413
- stats_btn.click(
414
- fn=firebase_cache.get_stats,
415
- outputs=[stats_display]
416
- ).then(
417
- fn=lambda: gr.update(visible=True),
418
- outputs=[stats_display]
419
- )
420
-
421
  # Example inputs
422
  gr.Markdown("### Examples")
423
- examples = [
424
- ["makay3nich bli katkhdam bzaf", "Latin → Arabic"],
425
- ["rah bayn dkchi li katdir kolchi 3ay9 bik", "Latin → Arabic"],
426
- ["wach na9dar nakhod caipirinha, 3afak", "Latin → Arabic"],
427
- ["ghadi temchi f lkhedma mzyan", "Latin → Arabic"]
428
- ]
429
-
430
- gr.Examples(
431
- examples=examples,
432
- inputs=[input_text],
433
- outputs=output_text,
434
- fn=transliterate,
435
- cache_examples=False
436
- )
437
-
438
  # Event handlers
439
- translate_btn.click(
440
- fn=transliterate,
441
- inputs=[input_text],
442
- outputs=output_text
443
- ).then(
444
- fn=lambda: gr.update(visible=True),
445
- outputs=[correction_status]
446
- )
447
-
448
- clear_btn.click(
449
- fn=lambda: ("", ""),
450
- outputs=[input_text, output_text]
451
- )
452
-
453
- input_text.submit(
454
- fn=transliterate,
455
- inputs=[input_text],
456
- outputs=output_text
457
- )
458
-
459
- save_correction_btn.click(
460
- fn=save_correction,
461
- inputs=[input_text, output_text],
462
- outputs=[correction_status]
463
- ).then(
464
- fn=lambda: gr.update(visible=True),
465
- outputs=[correction_status]
466
- )
467
-
468
  # Information
469
- gr.Markdown(
470
- """
471
  ### About
472
  This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
473
 
@@ -482,12 +500,12 @@ def create_interface():
482
  1. Use the Arabic keyboard to correct any wrong translations
483
  2. Click "Save Correction" to store your improvement
484
  3. Your corrections help train better models for everyone!
485
- """
486
- )
487
-
488
  return demo
489
 
 
490
  # Launch the app
491
  if __name__ == "__main__":
492
  demo = create_interface()
493
- demo.launch(share=True)
 
9
  import firebase_admin
10
  from firebase_admin import credentials, firestore
11
 
12
+
13
  # Define the model architecture
14
  class CTCTransliterator(nn.Module):
15
+
16
+ def __init__(self,
17
+ input_dim,
18
+ hidden_dim,
19
+ output_dim,
20
+ num_layers=3,
21
+ dropout=0.3,
22
+ upsample_factor=3):
23
  super().__init__()
24
  self.embed = nn.Embedding(input_dim, hidden_dim, padding_idx=0)
25
+ self.lstm = nn.LSTM(hidden_dim,
26
+ hidden_dim,
27
+ num_layers=num_layers,
28
+ bidirectional=True,
29
+ dropout=dropout)
30
  self.layer_norm = nn.LayerNorm(hidden_dim * 2)
31
  self.dropout = nn.Dropout(dropout)
32
  self.upsample_factor = upsample_factor
 
41
 
42
  # (seq_len, batch, hidden) → (batch, hidden, seq_len)
43
  x = x.permute(1, 2, 0)
44
+ x = F.interpolate(x,
45
+ scale_factor=self.upsample_factor,
46
+ mode='linear',
47
+ align_corners=False)
48
  # → (batch, hidden, seq_len*upsample_factor)
49
+ x = x.permute(2, 0,
50
+ 1) # back to (seq_len*upsample_factor, batch, hidden)
51
 
52
  x = self.fc(x)
53
  x = x.log_softmax(dim=2)
54
  return x
55
 
56
+
57
  # Firebase Cache System
58
  class FirebaseCache:
59
+
60
  def __init__(self):
61
  self.db = None
62
  self.init_firebase()
63
+
64
  def init_firebase(self):
65
  """Initialize Firebase connection"""
66
  try:
 
70
  if os.getenv('FIREBASE_CREDENTIALS'):
71
  # Parse credentials from environment variable
72
  import base64
73
+ cred_data = json.loads(
74
+ base64.b64decode(
75
+ os.getenv('FIREBASE_CREDENTIALS')).decode())
76
  cred = credentials.Certificate(cred_data)
77
  elif os.path.exists('firebase-credentials.json'):
78
  # For local development
79
  cred = credentials.Certificate('firebase-credentials.json')
80
  else:
81
+ print(
82
+ "No Firebase credentials found. Using local cache fallback."
83
+ )
84
  return
85
+
86
  firebase_admin.initialize_app(cred)
87
  self.db = firestore.client()
88
  print("Firebase initialized successfully!")
89
  else:
90
  self.db = firestore.client()
91
+
92
  except Exception as e:
93
  print(f"Firebase initialization failed: {e}")
94
  print("Falling back to local cache mode")
95
  self.db = None
96
+
97
+ def _create_cache_key(self, input_text, direction):
98
  """Create a safe document key for Firestore"""
99
  import hashlib
100
  # Create hash to handle special characters and length limits
101
+ key = f"{input_text}_{direction}"
102
  return hashlib.md5(key.encode()).hexdigest()
103
+
104
def get(self, input_text, direction):
    """Get cached translation from Firebase.

    Args:
        input_text: Text whose transliteration is being looked up.
        direction: Direction label that is part of the cache key.

    Returns:
        The stored ``'output'`` field (``''`` if the field is absent), or
        ``None`` when the cache is disabled, no document exists for the
        key, or reading the document raises.
    """
    if not self.db:
        return None

    try:
        doc_key = self._create_cache_key(input_text, direction)
        # Build the reference once and reuse it for the read and the update.
        doc_ref = self.db.collection('translations').document(doc_key)
        doc = doc_ref.get()
        if not doc.exists:
            return None
        data = doc.to_dict()
    except Exception as e:
        print(f"Cache read error: {e}")
        return None

    # Usage bookkeeping is best-effort: the value has already been read, so a
    # failed write must not turn this hit into a miss (a miss makes the caller
    # recompute and overwrite the document through set()).
    try:
        doc_ref.update({
            'usage_count': data.get('usage_count', 0) + 1,
            'last_used': datetime.now(),
        })
    except Exception as e:
        print(f"Cache usage update error: {e}")

    print(f"Cache hit: {input_text}")
    return data.get('output', '')
130
+
131
def set(self, input_text, direction, output):
    """Write a transliteration record to the 'translations' collection.

    The record is stored under the key from ``_create_cache_key`` and holds
    the input, direction and output, an empty ``corrected_output``, two
    timestamps taken at write time and a ``usage_count`` of 1.

    Args:
        input_text: Source text that was transliterated.
        direction: Direction label, stored and used in the key.
        output: Transliteration to store.

    Returns:
        ``True`` if the write went through; ``False`` if the cache is
        disabled or any step raised (the error is printed, not re-raised).
    """
    if self.db:
        try:
            key = self._create_cache_key(input_text, direction)
            record = {
                'input': input_text,
                'direction': direction,
                'output': output,
                'corrected_output': '',
                'timestamp': datetime.now(),
                'last_used': datetime.now(),
                'usage_count': 1,
            }
            collection = self.db.collection('translations')
            collection.document(key).set(record)
            print(f"Cached: {input_text} → {output}")
            return True
        except Exception as err:
            print(f"Cache write error: {err}")
    return False
155
+
156
def update_correction(self, input_text, direction, corrected_output):
    """Attach a user-supplied correction to an existing cache record.

    Only ``corrected_output`` and ``correction_timestamp`` are sent; the
    document is addressed with the key from ``_create_cache_key``.

    NOTE(review): this uses ``update()``, not ``set()``; per the Firestore
    client docs an update presumably fails when the document does not
    exist, which would surface here as a ``False`` return - confirm.

    Args:
        input_text: Source text the correction belongs to.
        direction: Direction label used in the key.
        corrected_output: The corrected transliteration.

    Returns:
        ``True`` when the update call completed, ``False`` when the cache
        is disabled or anything raised (the error is printed).
    """
    if self.db:
        try:
            key = self._create_cache_key(input_text, direction)
            target = self.db.collection('translations').document(key)
            target.update({
                'corrected_output': corrected_output,
                'correction_timestamp': datetime.now(),
            })
            print(f"Correction saved: {input_text} → {corrected_output}")
            return True
        except Exception as err:
            print(f"Correction save error: {err}")
    return False
175
+
176
  def get_stats(self):
177
  """Get cache statistics"""
178
  if not self.db:
179
  return "Firebase not connected"
180
+
181
  try:
182
  docs = self.db.collection('translations').get()
183
  total = len(docs)
184
+
185
  corrected = 0
186
  total_usage = 0
187
+
188
  for doc in docs:
189
  data = doc.to_dict()
190
  if data.get('corrected_output'):
191
  corrected += 1
192
  total_usage += data.get('usage_count', 0)
193
+
194
  return f"""
195
  Cache Statistics:
196
  • Total translations: {total}
 
198
  • Total usage count: {total_usage}
199
  • Average usage: {total_usage/total if total > 0 else 0:.1f} per translation
200
  """.strip()
201
+
202
  except Exception as e:
203
  return f"Error getting stats: {e}"
204
 
205
+
206
  # Load vocabularies and model
207
  def load_model_and_vocabs():
208
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
209
+
210
  # Load vocabularies
211
  with open('latin_stoi.json', 'r', encoding='utf-8') as f:
212
  latin_stoi = json.load(f)
213
  with open('latin_itos.json', 'r', encoding='utf-8') as f:
214
  latin_itos = json.load(f)
215
+
216
  with open('arabic_stoi.json', 'r', encoding='utf-8') as f:
217
  arabic_stoi = json.load(f)
218
  with open('arabic_itos.json', 'r', encoding='utf-8') as f:
219
+ arabic_itos = json.load(f)
220
+
221
  # Initialize model
222
+ model = CTCTransliterator(len(latin_stoi),
223
+ 256,
224
+ len(arabic_stoi),
225
+ num_layers=3,
226
+ dropout=0.3,
227
+ upsample_factor=2).to(device)
228
+
 
 
229
  # Load trained weights
230
+ model.load_state_dict(
231
+ torch.load('best_model.pth', map_location=device, weights_only=False))
232
  model.eval()
233
+
234
+ blank_id = arabic_stoi.get('<blank>', len(arabic_itos) - 1)
235
  return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device
236
 
237
+
238
  # Load everything at startup
239
# Import-time side effects: the model, vocabularies and cache client are
# created once here and shared as module globals by the handlers below.
(model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id,
 device) = load_model_and_vocabs()
firebase_cache = FirebaseCache()
242
 
243
+
244
  def encode_text(text, vocab):
245
  """Encode text using vocabulary"""
246
+ return torch.tensor([vocab.get(ch, 0) for ch in text.strip()],
247
+ dtype=torch.long)
248
+
249
 
250
  def greedy_decode(log_probs, blank_id, itos, stoi):
251
  """
252
  Decode CTC outputs using greedy decoding.
253
  """
254
+ eos_id = stoi.get('<eos>', len(stoi) - 2)
255
  preds = log_probs.argmax(2).T.cpu().numpy() # (B, T)
256
  results = []
257
  raw_results = []
258
  print(eos_id, blank_id)
259
  print(stoi)
260
  print(type(blank_id))
261
+ print(stoi.get('<eos>', 0))
262
  for i, pred in enumerate(preds):
263
  prev = None
264
  decoded = []
 
270
  break
271
  # CTC collapse: skip blanks and repeated characters
272
  if p != blank_id and p != prev:
273
+ decoded.append(itos[str(p)])
274
  prev = p
275
  raw_result.append(itos[str(p)])
276
 
 
280
 
281
  return results
282
 
283
+
284
  def transliterate_latin_to_arabic(text):
285
  """Transliterate Latin script to Arabic script with Firebase caching"""
286
  if not text.strip():
287
  return ""
288
+
289
  # Check Firebase cache first
290
  cached_result = firebase_cache.get(text, "Latin → Arabic")
291
  if cached_result:
292
  return cached_result
293
+
294
  try:
295
  # Encode input text
296
  src = encode_text(text, latin_stoi).unsqueeze(1).to(device)
297
+
298
  # Generate prediction
299
  with torch.no_grad():
300
  out = model(src)
301
+
302
  # Decode output
303
  decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
304
  result = decoded[0] if decoded else ""
305
+
306
  # Cache the result in Firebase
307
  firebase_cache.set(text, "Latin → Arabic", result)
308
+
309
  return result
310
+
311
  except Exception as e:
312
  return f"Error: {str(e)}"
313
 
314
+
315
  def transliterate_arabic_to_latin(text):
316
  """Transliterate Arabic script to Latin script (placeholder)"""
317
  return "Arabic to Latin transliteration not implemented yet."
318
 
319
+
320
def transliterate(text, direction):
    """Route a request to the transliterator for the chosen direction.

    For "Latin → Arabic" the text is lower-cased before it is handed to
    ``transliterate_latin_to_arabic``; any other direction value goes to
    ``transliterate_arabic_to_latin`` with the text unchanged.
    """
    if direction != "Latin → Arabic":
        return transliterate_arabic_to_latin(text)

    lowered = text.lower()
    return transliterate_latin_to_arabic(lowered)
326
 
327
 
328
def save_correction(input_text, direction, corrected_output):
    """Save a user's correction for a previously transliterated input.

    Args:
        input_text: The text as it currently appears in the input box.
        direction: Direction label the text was transliterated with.
        corrected_output: The user-edited transliteration.

    Returns:
        A status message for the UI: a thank-you on success, otherwise a
        notice that the correction could not be saved.
    """
    # transliterate() lower-cases Latin input before it is cached, so the
    # record is keyed on the lower-cased text. Apply the same normalisation
    # here; otherwise input containing capitals addresses a document that
    # was never written and the correction cannot be saved.
    if direction == "Latin → Arabic":
        input_text = input_text.lower()

    saved = firebase_cache.update_correction(input_text, direction,
                                             corrected_output)
    if saved:
        return "Correction saved to the database! Thank you for improving the model."
    return "Could not save correction to database."
335
 
336
+
337
  # Arabic keyboard layout
338
# One inner list per keyboard row, in display order. An entry may be longer
# than one character (the lam-alef key inserts two letters).
arabic_keys = [
    # Row 1 (12 keys)
    ['ض', 'ص', 'ث', 'ق', 'ف', 'غ', 'ع', 'ه', 'خ', 'ح', 'ج', 'د'],
    # Row 2 (11 keys)
    ['ش', 'س', 'ي', 'ب', 'ل', 'ا', 'ت', 'ن', 'م', 'ك', 'ط'],
    # Row 3 (10 keys)
    ['ئ', 'ء', 'ؤ', 'ر', 'لا', 'ى', 'ة', 'و', 'ز', 'ظ'],
    # Row 4 (11 keys): one letter followed by the Arabic-Indic digits 1-9, 0
    ['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠'],
]
342
+
 
343
 
344
  # Create Gradio interface
345
  def create_interface():
346
+ with gr.Blocks(title="Darija Transliterator",
347
+ theme=gr.themes.Soft()) as demo:
348
+ gr.Markdown("""
349
  # Darija Transliterator
350
  Convert between Latin script and Arabic script for Moroccan Darija
351
 
352
  **Firebase-Powered**: Persistent caching across sessions
353
  **Arabic Keyboard**: Built-in Arabic keyboard for corrections
354
  **Real-time Stats**: Live usage analytics
355
+ """)
356
+
 
357
  # Stats section
358
  with gr.Row():
359
  stats_btn = gr.Button("Show Statistics", variant="secondary")
360
+ stats_display = gr.Textbox(label="Firebase Statistics",
361
+ interactive=False,
362
+ visible=False,
363
+ lines=5)
364
+
 
 
365
  with gr.Row():
366
  with gr.Column(scale=1):
367
+ direction = gr.Radio(
368
+ choices=["Latin → Arabic"],
369
+ value="Latin → Arabic",
370
+ label="Translation Direction")
371
+
372
  input_text = gr.Textbox(
373
  placeholder="Enter text to transliterate...",
374
  label="Input Text",
375
  lines=4,
376
+ max_lines=10)
377
+
 
378
  with gr.Row():
379
  clear_btn = gr.Button("Clear", variant="secondary")
380
+ translate_btn = gr.Button("Transliterate",
381
+ variant="primary")
382
+
383
  with gr.Column(scale=1):
384
+ output_text = gr.Textbox(label="Output",
385
+ lines=4,
386
+ max_lines=10,
387
+ interactive=True)
388
+
 
 
389
  # Arabic Keyboard
390
  gr.Markdown("### Arabic Keyboard")
391
  gr.Markdown("*Click letters to edit the output text above*")
392
+
393
  with gr.Group():
394
  for row in arabic_keys:
395
  with gr.Row():
 
397
  btn = gr.Button(char, size="sm", scale=1)
398
  btn.click(
399
  fn=None,
400
+ js=
401
+ f"(output_text) => output_text + '{char}'",
402
  inputs=[output_text],
403
  outputs=[output_text],
404
  show_progress=False,
405
+ queue=False)
406
+
 
407
  with gr.Row():
408
  space_btn = gr.Button("Space", size="sm", scale=2)
409
+ backspace_btn = gr.Button("⌫ Backspace",
410
+ size="sm",
411
+ scale=2)
412
+ clear_output_btn = gr.Button("Clear Output",
413
+ size="sm",
414
+ scale=2)
415
+
416
  # Correction system
417
  with gr.Group():
418
  gr.Markdown("### Correction System")
419
+ correction_status = gr.Textbox(label="Status",
420
+ interactive=False,
421
+ visible=False)
422
+ save_correction_btn = gr.Button("Save Correction",
423
+ variant="secondary")
424
+
 
425
  # Keyboard utility buttons
426
+ space_btn.click(fn=None,
427
+ js="(output_text) => output_text + ' '",
428
+ inputs=[output_text],
429
+ outputs=[output_text],
430
+ show_progress=False,
431
+ queue=False)
432
+
433
+ backspace_btn.click(fn=None,
434
+ js="(output_text) => output_text.slice(0, -1)",
435
+ inputs=[output_text],
436
+ outputs=[output_text],
437
+ show_progress=False,
438
+ queue=False)
439
+
440
+ clear_output_btn.click(fn=None,
441
+ js="() => ''",
442
+ outputs=[output_text],
443
+ show_progress=False,
444
+ queue=False)
445
+
 
 
 
 
 
 
446
  # Stats button
447
+ stats_btn.click(fn=firebase_cache.get_stats,
448
+ outputs=[stats_display
449
+ ]).then(fn=lambda: gr.update(visible=True),
450
+ outputs=[stats_display])
451
+
 
 
 
452
  # Example inputs
453
  gr.Markdown("### Examples")
454
+ examples = [["makay3nich bli katkhdam bzaf", "Latin → Arabic"],
455
+ [
456
+ "rah bayn dkchi li katdir kolchi 3ay9 bik",
457
+ "Latin → Arabic"
458
+ ],
459
+ ["wach na9dar nakhod caipirinha, 3afak", "Latin → Arabic"],
460
+ ["ghadi temchi f lkhedma mzyan", "Latin → Arabic"]]
461
+
462
+ gr.Examples(examples=examples,
463
+ inputs=[input_text, direction],
464
+ outputs=output_text,
465
+ fn=transliterate,
466
+ cache_examples=False)
467
+
 
468
  # Event handlers
469
+ translate_btn.click(fn=transliterate,
470
+ inputs=[input_text, direction],
471
+ outputs=output_text).then(
472
+ fn=lambda: gr.update(visible=True),
473
+ outputs=[correction_status])
474
+
475
+ clear_btn.click(fn=lambda: ("", ""), outputs=[input_text, output_text])
476
+
477
+ input_text.submit(fn=transliterate,
478
+ inputs=[input_text, direction],
479
+ outputs=output_text)
480
+
481
+ save_correction_btn.click(fn=save_correction,
482
+ inputs=[input_text, direction, output_text],
483
+ outputs=[correction_status]).then(
484
+ fn=lambda: gr.update(visible=True),
485
+ outputs=[correction_status])
486
+
 
 
 
 
 
 
 
 
 
 
 
487
  # Information
488
+ gr.Markdown("""
 
489
  ### About
490
  This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.
491
 
 
500
  1. Use the Arabic keyboard to correct any wrong translations
501
  2. Click "Save Correction" to store your improvement
502
  3. Your corrections help train better models for everyone!
503
+ """)
504
+
 
505
  return demo
506
 
507
+
508
  # Launch the app
509
  if __name__ == "__main__":
510
  demo = create_interface()
511
+ demo.launch(share=True)