SondosM committed on
Commit
666de0b
·
verified ·
1 Parent(s): ea286f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -25
app.py CHANGED
@@ -3,7 +3,6 @@ import re
3
  import json
4
  import logging
5
  import warnings
6
- import urllib.request
7
  from pathlib import Path
8
  from typing import List, Dict, Optional, Tuple
9
  from dataclasses import dataclass, field
@@ -26,14 +25,30 @@ logger = logging.getLogger("ArabicSignNLP")
26
 
27
  # ----- Project Configuration -----
28
  class Config:
 
 
29
  CSV_PATH: str = os.getenv("CSV_PATH", "arabic_sign_lang_features.csv")
 
 
30
  KEYPOINTS_FOLDER: str = os.getenv("KEYPOINTS_FOLDER", "keypoints")
 
 
31
  SEQUENCE_OUTPUT_PATH: str = "/tmp/sequence.txt"
 
 
32
  EMBEDDING_MODEL: str = "aubmindlab/bert-base-arabertv2"
 
 
33
  SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.72"))
 
 
34
  INCLUDE_PREPOSITION_WORDS: bool = False
 
 
35
  API_HOST: str = "0.0.0.0"
36
- API_PORT: int = 7860
 
 
37
  CSV_LABEL_COLUMN: str = "label"
38
 
39
 
@@ -223,19 +238,11 @@ class SemanticSignMatcher:
223
 
224
  def _load_database(self, csv_path: str, label_column: str):
225
  if not os.path.exists(csv_path):
226
- logger.info("CSV not found locally. Downloading from Hugging Face...")
227
- url = "https://huggingface.co/spaces/SondosM/avatarAPI/resolve/main/arabic_sign_lang_features.csv"
228
- try:
229
- urllib.request.urlretrieve(url, csv_path)
230
- logger.info("CSV downloaded successfully.")
231
- except Exception as e:
232
- logger.warning(f"Failed to download CSV: {e}. No word signs loaded.")
233
- return
234
-
235
  df = pd.read_csv(csv_path, low_memory=False)
236
  if label_column not in df.columns:
237
  raise ValueError(f"Column '{label_column}' not found. Available: {list(df.columns)}")
238
-
239
  all_labels = df[label_column].dropna().unique().tolist()
240
  arabic_labels = [
241
  str(l) for l in all_labels
@@ -268,28 +275,26 @@ class SemanticSignMatcher:
268
  return SignMatch(found=False, sign_label="", confidence=0.0, method="none")
269
  norm_word = self._normalize_label(word_text)
270
  norm_lemma = self._normalize_label(lemma) if lemma else ""
271
-
272
  if norm_word in self._word_signs:
273
  idx = self._word_signs.index(norm_word)
274
  return SignMatch(True, self._raw_labels[idx], 1.0, "exact")
275
-
276
  if norm_lemma and norm_lemma != norm_word and norm_lemma in self._word_signs:
277
  idx = self._word_signs.index(norm_lemma)
278
  return SignMatch(True, self._raw_labels[idx], 0.95, "lemma")
279
-
280
  if self._model is None or self._sign_embeddings is None:
281
  return SignMatch(False, "", 0.0, "none")
282
-
283
  candidates = list({norm_word, norm_lemma} - {""})
284
  embs = self._model.encode(candidates, convert_to_tensor=True, device=self._device, batch_size=len(candidates))
285
  scores = util.cos_sim(embs, self._sign_embeddings)
286
  best_val = float(scores.max())
287
  best_idx = int(scores.argmax() % len(self._word_signs))
288
-
289
  if best_val >= self.threshold:
290
  return SignMatch(True, self._raw_labels[best_idx], best_val, "semantic")
291
  return SignMatch(False, self._raw_labels[best_idx] if self._raw_labels else "", best_val, "none")
292
 
 
 
 
293
  @property
294
  def available_signs(self) -> List[str]:
295
  return self._raw_labels.copy()
@@ -364,7 +369,6 @@ class BlenderSequenceWriter:
364
  missing_files = self._check_missing_keypoints(plan)
365
  with open(self.output_path, "w", encoding="utf-8") as f:
366
  f.write("\n".join(identifiers))
367
-
368
  sign_steps = [s for s in plan if s.action_type == ActionType.SIGN]
369
  letter_steps = [s for s in plan if s.action_type == ActionType.LETTER]
370
  return {
@@ -440,8 +444,8 @@ logger.info("All components ready.")
440
 
441
  # ----- FastAPI App -----
442
  class TranslateRequest(BaseModel):
443
- text: str = Field(description="Arabic input text (Fus-ha or Ammiya)", min_length=1, max_length=4000)
444
- save_sequence: bool = Field(default=False)
445
 
446
 
447
  class StepDetail(BaseModel):
@@ -463,7 +467,11 @@ class TranslateResponse(BaseModel):
463
  detailed_plan: List[StepDetail]
464
 
465
 
466
- app = FastAPI(title="Arabic Sign Language NLP API")
 
 
 
 
467
 
468
  app.add_middleware(
469
  CORSMiddleware,
@@ -488,10 +496,8 @@ def translate_post(request: TranslateRequest):
488
  result = translator.translate(request.text, save_to_file=request.save_sequence)
489
  except Exception as e:
490
  raise HTTPException(status_code=500, detail=str(e))
491
-
492
  if result["status"] == "error":
493
  raise HTTPException(status_code=422, detail=result["message"])
494
-
495
  return TranslateResponse(
496
  status=result["status"],
497
  input_text=request.text,
@@ -501,10 +507,36 @@ def translate_post(request: TranslateRequest):
501
  letter_count=result.get("letter_count", 0),
502
  missing_keypoint_files=result.get("missing_keypoint_files", []),
503
  detailed_plan=[
504
- StepDetail(**s) for s in result.get("detailed_plan", [])
 
 
505
  ],
506
  )
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  if __name__ == "__main__":
509
  import uvicorn
510
- uvicorn.run(app, host=Config.API_HOST, port=Config.API_PORT)
 
3
  import json
4
  import logging
5
  import warnings
 
6
  from pathlib import Path
7
  from typing import List, Dict, Optional, Tuple
8
  from dataclasses import dataclass, field
 
25
 
26
  # ----- Project Configuration -----
27
  class Config:
28
+ # Path to your CSV dataset containing sign labels
29
+ # On HF Spaces, upload your CSV to the repo and set the path here
30
  CSV_PATH: str = os.getenv("CSV_PATH", "arabic_sign_lang_features.csv")
31
+
32
+ # Folder where .npy keypoint files are stored (optional on HF Spaces)
33
  KEYPOINTS_FOLDER: str = os.getenv("KEYPOINTS_FOLDER", "keypoints")
34
+
35
+ # Output file path for Blender sequence
36
  SEQUENCE_OUTPUT_PATH: str = "/tmp/sequence.txt"
37
+
38
+ # AraBERT model for Arabic semantic understanding
39
  EMBEDDING_MODEL: str = "aubmindlab/bert-base-arabertv2"
40
+
41
+ # Similarity threshold for sign matching
42
  SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.72"))
43
+
44
+ # Include prepositions in signing
45
  INCLUDE_PREPOSITION_WORDS: bool = False
46
+
47
+ # FastAPI server settings
48
  API_HOST: str = "0.0.0.0"
49
+ API_PORT: int = 7860 # HF Spaces uses port 7860
50
+
51
+ # Column name in your CSV that contains the sign labels
52
  CSV_LABEL_COLUMN: str = "label"
53
 
54
 
 
238
 
239
  def _load_database(self, csv_path: str, label_column: str):
240
  if not os.path.exists(csv_path):
241
+ logger.warning(f"CSV not found at {csv_path}. No word signs loaded.")
242
+ return
 
 
 
 
 
 
 
243
  df = pd.read_csv(csv_path, low_memory=False)
244
  if label_column not in df.columns:
245
  raise ValueError(f"Column '{label_column}' not found. Available: {list(df.columns)}")
 
246
  all_labels = df[label_column].dropna().unique().tolist()
247
  arabic_labels = [
248
  str(l) for l in all_labels
 
275
  return SignMatch(found=False, sign_label="", confidence=0.0, method="none")
276
  norm_word = self._normalize_label(word_text)
277
  norm_lemma = self._normalize_label(lemma) if lemma else ""
 
278
  if norm_word in self._word_signs:
279
  idx = self._word_signs.index(norm_word)
280
  return SignMatch(True, self._raw_labels[idx], 1.0, "exact")
 
281
  if norm_lemma and norm_lemma != norm_word and norm_lemma in self._word_signs:
282
  idx = self._word_signs.index(norm_lemma)
283
  return SignMatch(True, self._raw_labels[idx], 0.95, "lemma")
 
284
  if self._model is None or self._sign_embeddings is None:
285
  return SignMatch(False, "", 0.0, "none")
 
286
  candidates = list({norm_word, norm_lemma} - {""})
287
  embs = self._model.encode(candidates, convert_to_tensor=True, device=self._device, batch_size=len(candidates))
288
  scores = util.cos_sim(embs, self._sign_embeddings)
289
  best_val = float(scores.max())
290
  best_idx = int(scores.argmax() % len(self._word_signs))
 
291
  if best_val >= self.threshold:
292
  return SignMatch(True, self._raw_labels[best_idx], best_val, "semantic")
293
  return SignMatch(False, self._raw_labels[best_idx] if self._raw_labels else "", best_val, "none")
294
 
295
+ def letter_to_label(self, arabic_letter: str) -> Optional[str]:
296
+ return ARABIC_LETTER_TO_LABEL.get(arabic_letter)
297
+
298
  @property
299
  def available_signs(self) -> List[str]:
300
  return self._raw_labels.copy()
 
369
  missing_files = self._check_missing_keypoints(plan)
370
  with open(self.output_path, "w", encoding="utf-8") as f:
371
  f.write("\n".join(identifiers))
 
372
  sign_steps = [s for s in plan if s.action_type == ActionType.SIGN]
373
  letter_steps = [s for s in plan if s.action_type == ActionType.LETTER]
374
  return {
 
444
 
445
  # ----- FastAPI App -----
446
  class TranslateRequest(BaseModel):
447
+ text: str = Field(description="Arabic input text (Fus-ha or Ammiya)", min_length=1, max_length=4000, examples=["انا عايز اروح المدرسة"])
448
+ save_sequence: bool = Field(default=False, description="Save sequence file to /tmp/sequence.txt")
449
 
450
 
451
  class StepDetail(BaseModel):
 
467
  detailed_plan: List[StepDetail]
468
 
469
 
470
+ app = FastAPI(
471
+ title="Arabic Sign Language NLP API",
472
+ description="Translates Arabic text (Fus-ha and Ammiya) into sign animation sequences.",
473
+ version="1.0.0",
474
+ )
475
 
476
  app.add_middleware(
477
  CORSMiddleware,
 
496
  result = translator.translate(request.text, save_to_file=request.save_sequence)
497
  except Exception as e:
498
  raise HTTPException(status_code=500, detail=str(e))
 
499
  if result["status"] == "error":
500
  raise HTTPException(status_code=422, detail=result["message"])
 
501
  return TranslateResponse(
502
  status=result["status"],
503
  input_text=request.text,
 
507
  letter_count=result.get("letter_count", 0),
508
  missing_keypoint_files=result.get("missing_keypoint_files", []),
509
  detailed_plan=[
510
+ StepDetail(type=s["type"], identifier=s["identifier"], source_word=s["source_word"],
511
+ confidence=s["confidence"], method=s["method"])
512
+ for s in result.get("detailed_plan", [])
513
  ],
514
  )
515
 
516
+
517
+ @app.get("/translate")
518
+ def translate_get(
519
+ text: str = Query(description="Arabic text to translate"),
520
+ save_sequence: bool = Query(default=False),
521
+ ):
522
+ return translate_post(TranslateRequest(text=text, save_sequence=save_sequence))
523
+
524
+
525
+ @app.get("/signs")
526
+ def list_signs():
527
+ return {"total": len(sign_matcher.available_signs), "signs": sign_matcher.available_signs}
528
+
529
+
530
+ @app.get("/sequence-file")
531
+ def read_sequence_file():
532
+ path = Config.SEQUENCE_OUTPUT_PATH
533
+ if not os.path.exists(path):
534
+ raise HTTPException(status_code=404, detail="Sequence file not found. Run a translation first.")
535
+ with open(path, "r", encoding="utf-8") as f:
536
+ lines = [line.strip() for line in f.readlines() if line.strip()]
537
+ return {"file_path": path, "sequence": lines, "count": len(lines)}
538
+
539
+
540
  if __name__ == "__main__":
541
  import uvicorn
542
+ uvicorn.run(app, host=Config.API_HOST, port=Config.API_PORT)