SondosM commited on
Commit
ea286f0
·
verified ·
1 Parent(s): ea1a0a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -69
app.py CHANGED
@@ -3,6 +3,7 @@ import re
3
  import json
4
  import logging
5
  import warnings
 
6
  from pathlib import Path
7
  from typing import List, Dict, Optional, Tuple
8
  from dataclasses import dataclass, field
@@ -25,30 +26,14 @@ logger = logging.getLogger("ArabicSignNLP")
25
 
26
  # ----- Project Configuration -----
27
  class Config:
28
- # Path to your CSV dataset containing sign labels
29
- # On HF Spaces, upload your CSV to the repo and set the path here
30
  CSV_PATH: str = os.getenv("CSV_PATH", "arabic_sign_lang_features.csv")
31
-
32
- # Folder where .npy keypoint files are stored (optional on HF Spaces)
33
  KEYPOINTS_FOLDER: str = os.getenv("KEYPOINTS_FOLDER", "keypoints")
34
-
35
- # Output file path for Blender sequence
36
  SEQUENCE_OUTPUT_PATH: str = "/tmp/sequence.txt"
37
-
38
- # AraBERT model for Arabic semantic understanding
39
  EMBEDDING_MODEL: str = "aubmindlab/bert-base-arabertv2"
40
-
41
- # Similarity threshold for sign matching
42
  SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.72"))
43
-
44
- # Include prepositions in signing
45
  INCLUDE_PREPOSITION_WORDS: bool = False
46
-
47
- # FastAPI server settings
48
  API_HOST: str = "0.0.0.0"
49
- API_PORT: int = 7860 # HF Spaces uses port 7860
50
-
51
- # Column name in your CSV that contains the sign labels
52
  CSV_LABEL_COLUMN: str = "label"
53
 
54
 
@@ -236,29 +221,30 @@ class SemanticSignMatcher:
236
  return self._normalizer.normalize_label(label)
237
  return label
238
 
239
- def _load_database(self, csv_path: str, label_column: str):
240
  if not os.path.exists(csv_path):
241
  logger.info("CSV not found locally. Downloading from Hugging Face...")
242
- import urllib.request
243
  url = "https://huggingface.co/spaces/SondosM/avatarAPI/resolve/main/arabic_sign_lang_features.csv"
244
  try:
245
  urllib.request.urlretrieve(url, csv_path)
246
  logger.info("CSV downloaded successfully.")
247
  except Exception as e:
248
  logger.warning(f"Failed to download CSV: {e}. No word signs loaded.")
249
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- df = pd.read_csv(csv_path, low_memory=False)
252
- if label_column not in df.columns:
253
- raise ValueError(f"Column '{label_column}' not found. Available: {list(df.columns)}")
254
- all_labels = df[label_column].dropna().unique().tolist()
255
- arabic_labels = [
256
- str(l) for l in all_labels
257
- if isinstance(l, str) and any("\u0600" <= c <= "\u06ff" for c in str(l))
258
- ]
259
- self._raw_labels = arabic_labels
260
- self._word_signs = arabic_labels.copy()
261
- logger.info(f"Database: {len(arabic_labels)} Arabic word labels loaded.")
262
  def _finalize_labels(self):
263
  if self._normalizer and self._raw_labels:
264
  self._word_signs = [self._normalize_label(l) for l in self._raw_labels]
@@ -282,26 +268,28 @@ class SemanticSignMatcher:
282
  return SignMatch(found=False, sign_label="", confidence=0.0, method="none")
283
  norm_word = self._normalize_label(word_text)
284
  norm_lemma = self._normalize_label(lemma) if lemma else ""
 
285
  if norm_word in self._word_signs:
286
  idx = self._word_signs.index(norm_word)
287
  return SignMatch(True, self._raw_labels[idx], 1.0, "exact")
 
288
  if norm_lemma and norm_lemma != norm_word and norm_lemma in self._word_signs:
289
  idx = self._word_signs.index(norm_lemma)
290
  return SignMatch(True, self._raw_labels[idx], 0.95, "lemma")
 
291
  if self._model is None or self._sign_embeddings is None:
292
  return SignMatch(False, "", 0.0, "none")
 
293
  candidates = list({norm_word, norm_lemma} - {""})
294
  embs = self._model.encode(candidates, convert_to_tensor=True, device=self._device, batch_size=len(candidates))
295
  scores = util.cos_sim(embs, self._sign_embeddings)
296
  best_val = float(scores.max())
297
  best_idx = int(scores.argmax() % len(self._word_signs))
 
298
  if best_val >= self.threshold:
299
  return SignMatch(True, self._raw_labels[best_idx], best_val, "semantic")
300
  return SignMatch(False, self._raw_labels[best_idx] if self._raw_labels else "", best_val, "none")
301
 
302
- def letter_to_label(self, arabic_letter: str) -> Optional[str]:
303
- return ARABIC_LETTER_TO_LABEL.get(arabic_letter)
304
-
305
  @property
306
  def available_signs(self) -> List[str]:
307
  return self._raw_labels.copy()
@@ -376,6 +364,7 @@ class BlenderSequenceWriter:
376
  missing_files = self._check_missing_keypoints(plan)
377
  with open(self.output_path, "w", encoding="utf-8") as f:
378
  f.write("\n".join(identifiers))
 
379
  sign_steps = [s for s in plan if s.action_type == ActionType.SIGN]
380
  letter_steps = [s for s in plan if s.action_type == ActionType.LETTER]
381
  return {
@@ -451,8 +440,8 @@ logger.info("All components ready.")
451
 
452
  # ----- FastAPI App -----
453
  class TranslateRequest(BaseModel):
454
- text: str = Field(description="Arabic input text (Fus-ha or Ammiya)", min_length=1, max_length=4000, examples=["انا عايز اروح المدرسة"])
455
- save_sequence: bool = Field(default=False, description="Save sequence file to /tmp/sequence.txt")
456
 
457
 
458
  class StepDetail(BaseModel):
@@ -474,11 +463,7 @@ class TranslateResponse(BaseModel):
474
  detailed_plan: List[StepDetail]
475
 
476
 
477
- app = FastAPI(
478
- title="Arabic Sign Language NLP API",
479
- description="Translates Arabic text (Fus-ha and Ammiya) into sign animation sequences.",
480
- version="1.0.0",
481
- )
482
 
483
  app.add_middleware(
484
  CORSMiddleware,
@@ -503,8 +488,10 @@ def translate_post(request: TranslateRequest):
503
  result = translator.translate(request.text, save_to_file=request.save_sequence)
504
  except Exception as e:
505
  raise HTTPException(status_code=500, detail=str(e))
 
506
  if result["status"] == "error":
507
  raise HTTPException(status_code=422, detail=result["message"])
 
508
  return TranslateResponse(
509
  status=result["status"],
510
  input_text=request.text,
@@ -514,36 +501,10 @@ def translate_post(request: TranslateRequest):
514
  letter_count=result.get("letter_count", 0),
515
  missing_keypoint_files=result.get("missing_keypoint_files", []),
516
  detailed_plan=[
517
- StepDetail(type=s["type"], identifier=s["identifier"], source_word=s["source_word"],
518
- confidence=s["confidence"], method=s["method"])
519
- for s in result.get("detailed_plan", [])
520
  ],
521
  )
522
 
523
-
524
- @app.get("/translate")
525
- def translate_get(
526
- text: str = Query(description="Arabic text to translate"),
527
- save_sequence: bool = Query(default=False),
528
- ):
529
- return translate_post(TranslateRequest(text=text, save_sequence=save_sequence))
530
-
531
-
532
- @app.get("/signs")
533
- def list_signs():
534
- return {"total": len(sign_matcher.available_signs), "signs": sign_matcher.available_signs}
535
-
536
-
537
- @app.get("/sequence-file")
538
- def read_sequence_file():
539
- path = Config.SEQUENCE_OUTPUT_PATH
540
- if not os.path.exists(path):
541
- raise HTTPException(status_code=404, detail="Sequence file not found. Run a translation first.")
542
- with open(path, "r", encoding="utf-8") as f:
543
- lines = [line.strip() for line in f.readlines() if line.strip()]
544
- return {"file_path": path, "sequence": lines, "count": len(lines)}
545
-
546
-
547
  if __name__ == "__main__":
548
  import uvicorn
549
- uvicorn.run(app, host=Config.API_HOST, port=Config.API_PORT)
 
3
  import json
4
  import logging
5
  import warnings
6
+ import urllib.request
7
  from pathlib import Path
8
  from typing import List, Dict, Optional, Tuple
9
  from dataclasses import dataclass, field
 
26
 
27
  # ----- Project Configuration -----
28
  class Config:
 
 
29
  CSV_PATH: str = os.getenv("CSV_PATH", "arabic_sign_lang_features.csv")
 
 
30
  KEYPOINTS_FOLDER: str = os.getenv("KEYPOINTS_FOLDER", "keypoints")
 
 
31
  SEQUENCE_OUTPUT_PATH: str = "/tmp/sequence.txt"
 
 
32
  EMBEDDING_MODEL: str = "aubmindlab/bert-base-arabertv2"
 
 
33
  SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.72"))
 
 
34
  INCLUDE_PREPOSITION_WORDS: bool = False
 
 
35
  API_HOST: str = "0.0.0.0"
36
+ API_PORT: int = 7860
 
 
37
  CSV_LABEL_COLUMN: str = "label"
38
 
39
 
 
221
  return self._normalizer.normalize_label(label)
222
  return label
223
 
224
+ def _load_database(self, csv_path: str, label_column: str):
225
  if not os.path.exists(csv_path):
226
  logger.info("CSV not found locally. Downloading from Hugging Face...")
 
227
  url = "https://huggingface.co/spaces/SondosM/avatarAPI/resolve/main/arabic_sign_lang_features.csv"
228
  try:
229
  urllib.request.urlretrieve(url, csv_path)
230
  logger.info("CSV downloaded successfully.")
231
  except Exception as e:
232
  logger.warning(f"Failed to download CSV: {e}. No word signs loaded.")
233
+ return
234
+
235
+ df = pd.read_csv(csv_path, low_memory=False)
236
+ if label_column not in df.columns:
237
+ raise ValueError(f"Column '{label_column}' not found. Available: {list(df.columns)}")
238
+
239
+ all_labels = df[label_column].dropna().unique().tolist()
240
+ arabic_labels = [
241
+ str(l) for l in all_labels
242
+ if isinstance(l, str) and any("\u0600" <= c <= "\u06ff" for c in str(l))
243
+ ]
244
+ self._raw_labels = arabic_labels
245
+ self._word_signs = arabic_labels.copy()
246
+ logger.info(f"Database: {len(arabic_labels)} Arabic word labels loaded.")
247
 
 
 
 
 
 
 
 
 
 
 
 
248
  def _finalize_labels(self):
249
  if self._normalizer and self._raw_labels:
250
  self._word_signs = [self._normalize_label(l) for l in self._raw_labels]
 
268
  return SignMatch(found=False, sign_label="", confidence=0.0, method="none")
269
  norm_word = self._normalize_label(word_text)
270
  norm_lemma = self._normalize_label(lemma) if lemma else ""
271
+
272
  if norm_word in self._word_signs:
273
  idx = self._word_signs.index(norm_word)
274
  return SignMatch(True, self._raw_labels[idx], 1.0, "exact")
275
+
276
  if norm_lemma and norm_lemma != norm_word and norm_lemma in self._word_signs:
277
  idx = self._word_signs.index(norm_lemma)
278
  return SignMatch(True, self._raw_labels[idx], 0.95, "lemma")
279
+
280
  if self._model is None or self._sign_embeddings is None:
281
  return SignMatch(False, "", 0.0, "none")
282
+
283
  candidates = list({norm_word, norm_lemma} - {""})
284
  embs = self._model.encode(candidates, convert_to_tensor=True, device=self._device, batch_size=len(candidates))
285
  scores = util.cos_sim(embs, self._sign_embeddings)
286
  best_val = float(scores.max())
287
  best_idx = int(scores.argmax() % len(self._word_signs))
288
+
289
  if best_val >= self.threshold:
290
  return SignMatch(True, self._raw_labels[best_idx], best_val, "semantic")
291
  return SignMatch(False, self._raw_labels[best_idx] if self._raw_labels else "", best_val, "none")
292
 
 
 
 
293
  @property
294
  def available_signs(self) -> List[str]:
295
  return self._raw_labels.copy()
 
364
  missing_files = self._check_missing_keypoints(plan)
365
  with open(self.output_path, "w", encoding="utf-8") as f:
366
  f.write("\n".join(identifiers))
367
+
368
  sign_steps = [s for s in plan if s.action_type == ActionType.SIGN]
369
  letter_steps = [s for s in plan if s.action_type == ActionType.LETTER]
370
  return {
 
440
 
441
  # ----- FastAPI App -----
442
  class TranslateRequest(BaseModel):
443
+ text: str = Field(description="Arabic input text (Fus-ha or Ammiya)", min_length=1, max_length=4000)
444
+ save_sequence: bool = Field(default=False)
445
 
446
 
447
  class StepDetail(BaseModel):
 
463
  detailed_plan: List[StepDetail]
464
 
465
 
466
+ app = FastAPI(title="Arabic Sign Language NLP API")
 
 
 
 
467
 
468
  app.add_middleware(
469
  CORSMiddleware,
 
488
  result = translator.translate(request.text, save_to_file=request.save_sequence)
489
  except Exception as e:
490
  raise HTTPException(status_code=500, detail=str(e))
491
+
492
  if result["status"] == "error":
493
  raise HTTPException(status_code=422, detail=result["message"])
494
+
495
  return TranslateResponse(
496
  status=result["status"],
497
  input_text=request.text,
 
501
  letter_count=result.get("letter_count", 0),
502
  missing_keypoint_files=result.get("missing_keypoint_files", []),
503
  detailed_plan=[
504
+ StepDetail(**s) for s in result.get("detailed_plan", [])
 
 
505
  ],
506
  )
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  if __name__ == "__main__":
509
  import uvicorn
510
+ uvicorn.run(app, host=Config.API_HOST, port=Config.API_PORT)