Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import re
|
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
import warnings
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import List, Dict, Optional, Tuple
|
| 8 |
from dataclasses import dataclass, field
|
|
@@ -25,30 +26,14 @@ logger = logging.getLogger("ArabicSignNLP")
|
|
| 25 |
|
| 26 |
# ----- Project Configuration -----
|
| 27 |
class Config:
    """Static settings for the Arabic sign-language NLP service.

    Entries read through the environment can be overridden at deploy time;
    everything else is a fixed constant.
    """

    # Path to the CSV dataset containing sign labels.
    # On HF Spaces, upload the CSV to the repo and point this at it.
    CSV_PATH: str = os.environ.get("CSV_PATH", "arabic_sign_lang_features.csv")

    # Folder where .npy keypoint files are stored (optional on HF Spaces).
    KEYPOINTS_FOLDER: str = os.environ.get("KEYPOINTS_FOLDER", "keypoints")

    # Output file path for the Blender sequence.
    SEQUENCE_OUTPUT_PATH: str = "/tmp/sequence.txt"

    # AraBERT model used for Arabic semantic understanding.
    EMBEDDING_MODEL: str = "aubmindlab/bert-base-arabertv2"

    # Similarity threshold for sign matching.
    SIMILARITY_THRESHOLD: float = float(os.environ.get("SIMILARITY_THRESHOLD", "0.72"))

    # Whether prepositions are included in signing.
    INCLUDE_PREPOSITION_WORDS: bool = False

    # FastAPI server settings.
    API_HOST: str = "0.0.0.0"
    API_PORT: int = 7860

    # Column name in the CSV that contains the sign labels.
    CSV_LABEL_COLUMN: str = "label"
|
| 53 |
|
| 54 |
|
|
@@ -236,29 +221,30 @@ class SemanticSignMatcher:
|
|
| 236 |
return self._normalizer.normalize_label(label)
|
| 237 |
return label
|
| 238 |
|
| 239 |
-
|
| 240 |
if not os.path.exists(csv_path):
|
| 241 |
logger.info("CSV not found locally. Downloading from Hugging Face...")
|
| 242 |
-
import urllib.request
|
| 243 |
url = "https://huggingface.co/spaces/SondosM/avatarAPI/resolve/main/arabic_sign_lang_features.csv"
|
| 244 |
try:
|
| 245 |
urllib.request.urlretrieve(url, csv_path)
|
| 246 |
logger.info("CSV downloaded successfully.")
|
| 247 |
except Exception as e:
|
| 248 |
logger.warning(f"Failed to download CSV: {e}. No word signs loaded.")
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
-
df = pd.read_csv(csv_path, low_memory=False)
|
| 252 |
-
if label_column not in df.columns:
|
| 253 |
-
raise ValueError(f"Column '{label_column}' not found. Available: {list(df.columns)}")
|
| 254 |
-
all_labels = df[label_column].dropna().unique().tolist()
|
| 255 |
-
arabic_labels = [
|
| 256 |
-
str(l) for l in all_labels
|
| 257 |
-
if isinstance(l, str) and any("\u0600" <= c <= "\u06ff" for c in str(l))
|
| 258 |
-
]
|
| 259 |
-
self._raw_labels = arabic_labels
|
| 260 |
-
self._word_signs = arabic_labels.copy()
|
| 261 |
-
logger.info(f"Database: {len(arabic_labels)} Arabic word labels loaded.")
|
| 262 |
def _finalize_labels(self):
|
| 263 |
if self._normalizer and self._raw_labels:
|
| 264 |
self._word_signs = [self._normalize_label(l) for l in self._raw_labels]
|
|
@@ -282,26 +268,28 @@ class SemanticSignMatcher:
|
|
| 282 |
return SignMatch(found=False, sign_label="", confidence=0.0, method="none")
|
| 283 |
norm_word = self._normalize_label(word_text)
|
| 284 |
norm_lemma = self._normalize_label(lemma) if lemma else ""
|
|
|
|
| 285 |
if norm_word in self._word_signs:
|
| 286 |
idx = self._word_signs.index(norm_word)
|
| 287 |
return SignMatch(True, self._raw_labels[idx], 1.0, "exact")
|
|
|
|
| 288 |
if norm_lemma and norm_lemma != norm_word and norm_lemma in self._word_signs:
|
| 289 |
idx = self._word_signs.index(norm_lemma)
|
| 290 |
return SignMatch(True, self._raw_labels[idx], 0.95, "lemma")
|
|
|
|
| 291 |
if self._model is None or self._sign_embeddings is None:
|
| 292 |
return SignMatch(False, "", 0.0, "none")
|
|
|
|
| 293 |
candidates = list({norm_word, norm_lemma} - {""})
|
| 294 |
embs = self._model.encode(candidates, convert_to_tensor=True, device=self._device, batch_size=len(candidates))
|
| 295 |
scores = util.cos_sim(embs, self._sign_embeddings)
|
| 296 |
best_val = float(scores.max())
|
| 297 |
best_idx = int(scores.argmax() % len(self._word_signs))
|
|
|
|
| 298 |
if best_val >= self.threshold:
|
| 299 |
return SignMatch(True, self._raw_labels[best_idx], best_val, "semantic")
|
| 300 |
return SignMatch(False, self._raw_labels[best_idx] if self._raw_labels else "", best_val, "none")
|
| 301 |
|
| 302 |
-
def letter_to_label(self, arabic_letter: str) -> Optional[str]:
    """Look up the sign label registered for a single Arabic letter.

    The result is whatever ``ARABIC_LETTER_TO_LABEL.get`` yields for the
    letter (presumably ``None`` when the letter has no entry -- the table is
    defined elsewhere in the file).
    """
    label = ARABIC_LETTER_TO_LABEL.get(arabic_letter)
    return label
|
| 304 |
-
|
| 305 |
@property
def available_signs(self) -> List[str]:
    """Return a new list holding the raw sign labels.

    A copy is handed out so callers cannot modify the matcher's own list.
    """
    return list(self._raw_labels)
|
|
@@ -376,6 +364,7 @@ class BlenderSequenceWriter:
|
|
| 376 |
missing_files = self._check_missing_keypoints(plan)
|
| 377 |
with open(self.output_path, "w", encoding="utf-8") as f:
|
| 378 |
f.write("\n".join(identifiers))
|
|
|
|
| 379 |
sign_steps = [s for s in plan if s.action_type == ActionType.SIGN]
|
| 380 |
letter_steps = [s for s in plan if s.action_type == ActionType.LETTER]
|
| 381 |
return {
|
|
@@ -451,8 +440,8 @@ logger.info("All components ready.")
|
|
| 451 |
|
| 452 |
# ----- FastAPI App -----
|
| 453 |
class TranslateRequest(BaseModel):
    """Request body accepted by the translate endpoint."""

    # Text to translate; validated to be between 1 and 4000 characters.
    text: str = Field(
        description="Arabic input text (Fus-ha or Ammiya)",
        min_length=1,
        max_length=4000,
    )
    # Forwarded to the translator as its ``save_to_file`` flag.
    save_sequence: bool = Field(default=False)
|
| 456 |
|
| 457 |
|
| 458 |
class StepDetail(BaseModel):
|
|
@@ -474,11 +463,7 @@ class TranslateResponse(BaseModel):
|
|
| 474 |
detailed_plan: List[StepDetail]
|
| 475 |
|
| 476 |
|
| 477 |
-
app = FastAPI(
|
| 478 |
-
title="Arabic Sign Language NLP API",
|
| 479 |
-
description="Translates Arabic text (Fus-ha and Ammiya) into sign animation sequences.",
|
| 480 |
-
version="1.0.0",
|
| 481 |
-
)
|
| 482 |
|
| 483 |
app.add_middleware(
|
| 484 |
CORSMiddleware,
|
|
@@ -503,8 +488,10 @@ def translate_post(request: TranslateRequest):
|
|
| 503 |
result = translator.translate(request.text, save_to_file=request.save_sequence)
|
| 504 |
except Exception as e:
|
| 505 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 506 |
if result["status"] == "error":
|
| 507 |
raise HTTPException(status_code=422, detail=result["message"])
|
|
|
|
| 508 |
return TranslateResponse(
|
| 509 |
status=result["status"],
|
| 510 |
input_text=request.text,
|
|
@@ -514,36 +501,10 @@ def translate_post(request: TranslateRequest):
|
|
| 514 |
letter_count=result.get("letter_count", 0),
|
| 515 |
missing_keypoint_files=result.get("missing_keypoint_files", []),
|
| 516 |
detailed_plan=[
|
| 517 |
-
StepDetail(
|
| 518 |
-
confidence=s["confidence"], method=s["method"])
|
| 519 |
-
for s in result.get("detailed_plan", [])
|
| 520 |
],
|
| 521 |
)
|
| 522 |
|
| 523 |
-
|
| 524 |
-
@app.get("/translate")
def translate_get(
    text: str = Query(description="Arabic text to translate"),
    save_sequence: bool = Query(default=False),
):
    """GET variant of the translate endpoint.

    Packs the query parameters into a ``TranslateRequest`` and delegates to
    ``translate_post``, returning that handler's result unchanged.
    """
    payload = TranslateRequest(text=text, save_sequence=save_sequence)
    return translate_post(payload)
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
@app.get("/signs")
def list_signs():
    """Return the sign labels exposed by the matcher together with their count."""
    # Read the property once so the count and the list come from the same
    # snapshot, and only one list is built for the response.
    signs = sign_matcher.available_signs
    return {"total": len(signs), "signs": signs}
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
@app.get("/sequence-file")
def read_sequence_file():
    """Return the non-empty, stripped lines of the saved sequence file.

    Returns:
        A dict with the file path, the list of lines, and the line count.

    Raises:
        HTTPException: 404 when the sequence file does not exist.
    """
    path = Config.SEQUENCE_OUTPUT_PATH
    # Open directly rather than checking existence first: a file removed
    # between the check and the open would otherwise surface as an unhandled
    # error instead of the intended 404.
    try:
        with open(path, "r", encoding="utf-8") as f:
            # Iterate the file lazily; blank lines are dropped after stripping.
            lines = [text for text in map(str.strip, f) if text]
    except FileNotFoundError:
        raise HTTPException(
            status_code=404,
            detail="Sequence file not found. Run a translation first.",
        ) from None
    return {"file_path": path, "sequence": lines, "count": len(lines)}
|
| 545 |
-
|
| 546 |
-
|
| 547 |
if __name__ == "__main__":
    # Serve the app with uvicorn when this file is executed as a script.
    import uvicorn

    server_options = {"host": Config.API_HOST, "port": Config.API_PORT}
    uvicorn.run(app, **server_options)
|
|
|
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
import warnings
|
| 6 |
+
import urllib.request
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import List, Dict, Optional, Tuple
|
| 9 |
from dataclasses import dataclass, field
|
|
|
|
| 26 |
|
| 27 |
# ----- Project Configuration -----
|
| 28 |
class Config:
    """Static settings for the Arabic sign-language NLP service.

    Entries read through the environment can be overridden at deploy time;
    everything else is a fixed constant.
    """

    # CSV dataset containing the sign labels.
    CSV_PATH: str = os.environ.get("CSV_PATH", "arabic_sign_lang_features.csv")

    # Folder holding the keypoint files.
    KEYPOINTS_FOLDER: str = os.environ.get("KEYPOINTS_FOLDER", "keypoints")

    # Where the generated sequence file is written.
    SEQUENCE_OUTPUT_PATH: str = "/tmp/sequence.txt"

    # Embedding model identifier.
    EMBEDDING_MODEL: str = "aubmindlab/bert-base-arabertv2"

    # Minimum similarity score for a sign match.
    SIMILARITY_THRESHOLD: float = float(os.environ.get("SIMILARITY_THRESHOLD", "0.72"))

    # Whether preposition words are included.
    INCLUDE_PREPOSITION_WORDS: bool = False

    # Address the API server binds to.
    API_HOST: str = "0.0.0.0"
    API_PORT: int = 7860

    # Name of the CSV column that holds the sign labels.
    CSV_LABEL_COLUMN: str = "label"
|
| 38 |
|
| 39 |
|
|
|
|
| 221 |
return self._normalizer.normalize_label(label)
|
| 222 |
return label
|
| 223 |
|
| 224 |
+
def _load_database(self, csv_path: str, label_column: str):
    """Load the Arabic word-sign labels from the CSV database.

    When ``csv_path`` does not exist locally, the CSV is first downloaded
    from the Hugging Face Space. If that download fails, a warning is logged
    and the method returns without touching the label lists.

    Args:
        csv_path: Location of the CSV file; also the download target.
        label_column: Name of the CSV column that holds the sign labels.

    Raises:
        ValueError: If ``label_column`` is not a column of the CSV.
    """
    if not os.path.exists(csv_path):
        logger.info("CSV not found locally. Downloading from Hugging Face...")
        url = "https://huggingface.co/spaces/SondosM/avatarAPI/resolve/main/arabic_sign_lang_features.csv"
        # Download to a sibling temporary name and rename only on success, so
        # an interrupted transfer never leaves a truncated file at csv_path
        # that a later start-up would mistake for the real database.
        partial_path = f"{csv_path}.part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, csv_path)
            logger.info("CSV downloaded successfully.")
        except Exception as e:
            logger.warning(f"Failed to download CSV: {e}. No word signs loaded.")
            # Best-effort cleanup of the partial download; it may not exist.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            return

    df = pd.read_csv(csv_path, low_memory=False)
    if label_column not in df.columns:
        raise ValueError(f"Column '{label_column}' not found. Available: {list(df.columns)}")

    all_labels = df[label_column].dropna().unique().tolist()
    # Keep only string labels containing at least one character from the
    # Arabic Unicode block (U+0600..U+06FF).
    arabic_labels = [
        label for label in all_labels
        if isinstance(label, str) and any("\u0600" <= ch <= "\u06ff" for ch in label)
    ]
    self._raw_labels = arabic_labels
    # Start from a separate copy of the raw labels.
    self._word_signs = arabic_labels.copy()
    logger.info(f"Database: {len(arabic_labels)} Arabic word labels loaded.")
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
def _finalize_labels(self):
|
| 249 |
if self._normalizer and self._raw_labels:
|
| 250 |
self._word_signs = [self._normalize_label(l) for l in self._raw_labels]
|
|
|
|
| 268 |
return SignMatch(found=False, sign_label="", confidence=0.0, method="none")
|
| 269 |
norm_word = self._normalize_label(word_text)
|
| 270 |
norm_lemma = self._normalize_label(lemma) if lemma else ""
|
| 271 |
+
|
| 272 |
if norm_word in self._word_signs:
|
| 273 |
idx = self._word_signs.index(norm_word)
|
| 274 |
return SignMatch(True, self._raw_labels[idx], 1.0, "exact")
|
| 275 |
+
|
| 276 |
if norm_lemma and norm_lemma != norm_word and norm_lemma in self._word_signs:
|
| 277 |
idx = self._word_signs.index(norm_lemma)
|
| 278 |
return SignMatch(True, self._raw_labels[idx], 0.95, "lemma")
|
| 279 |
+
|
| 280 |
if self._model is None or self._sign_embeddings is None:
|
| 281 |
return SignMatch(False, "", 0.0, "none")
|
| 282 |
+
|
| 283 |
candidates = list({norm_word, norm_lemma} - {""})
|
| 284 |
embs = self._model.encode(candidates, convert_to_tensor=True, device=self._device, batch_size=len(candidates))
|
| 285 |
scores = util.cos_sim(embs, self._sign_embeddings)
|
| 286 |
best_val = float(scores.max())
|
| 287 |
best_idx = int(scores.argmax() % len(self._word_signs))
|
| 288 |
+
|
| 289 |
if best_val >= self.threshold:
|
| 290 |
return SignMatch(True, self._raw_labels[best_idx], best_val, "semantic")
|
| 291 |
return SignMatch(False, self._raw_labels[best_idx] if self._raw_labels else "", best_val, "none")
|
| 292 |
|
|
|
|
|
|
|
|
|
|
| 293 |
@property
def available_signs(self) -> List[str]:
    """Return a new list holding the raw sign labels.

    A copy is handed out so callers cannot modify the matcher's own list.
    """
    return list(self._raw_labels)
|
|
|
|
| 364 |
missing_files = self._check_missing_keypoints(plan)
|
| 365 |
with open(self.output_path, "w", encoding="utf-8") as f:
|
| 366 |
f.write("\n".join(identifiers))
|
| 367 |
+
|
| 368 |
sign_steps = [s for s in plan if s.action_type == ActionType.SIGN]
|
| 369 |
letter_steps = [s for s in plan if s.action_type == ActionType.LETTER]
|
| 370 |
return {
|
|
|
|
| 440 |
|
| 441 |
# ----- FastAPI App -----
|
| 442 |
class TranslateRequest(BaseModel):
    """Request body accepted by the translate endpoint."""

    # Text to translate; validated to be between 1 and 4000 characters.
    text: str = Field(
        description="Arabic input text (Fus-ha or Ammiya)",
        min_length=1,
        max_length=4000,
    )
    # Forwarded to the translator as its ``save_to_file`` flag.
    save_sequence: bool = Field(default=False)
|
| 445 |
|
| 446 |
|
| 447 |
class StepDetail(BaseModel):
|
|
|
|
| 463 |
detailed_plan: List[StepDetail]
|
| 464 |
|
| 465 |
|
| 466 |
+
app = FastAPI(title="Arabic Sign Language NLP API")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
|
| 468 |
app.add_middleware(
|
| 469 |
CORSMiddleware,
|
|
|
|
| 488 |
result = translator.translate(request.text, save_to_file=request.save_sequence)
|
| 489 |
except Exception as e:
|
| 490 |
raise HTTPException(status_code=500, detail=str(e))
|
| 491 |
+
|
| 492 |
if result["status"] == "error":
|
| 493 |
raise HTTPException(status_code=422, detail=result["message"])
|
| 494 |
+
|
| 495 |
return TranslateResponse(
|
| 496 |
status=result["status"],
|
| 497 |
input_text=request.text,
|
|
|
|
| 501 |
letter_count=result.get("letter_count", 0),
|
| 502 |
missing_keypoint_files=result.get("missing_keypoint_files", []),
|
| 503 |
detailed_plan=[
|
| 504 |
+
StepDetail(**s) for s in result.get("detailed_plan", [])
|
|
|
|
|
|
|
| 505 |
],
|
| 506 |
)
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
if __name__ == "__main__":
    # Serve the app with uvicorn when this file is executed as a script.
    import uvicorn

    server_options = {"host": Config.API_HOST, "port": Config.API_PORT}
    uvicorn.run(app, **server_options)
|