Files changed (3) hide show
  1. main.py +31 -27
  2. normalize_bm_input.py +0 -80
  3. normalize_bm_output.py +0 -67
main.py CHANGED
@@ -2,17 +2,9 @@ import os
2
  import torch
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
-
6
  # Note: Keep the imports together for clarity
7
- from transformers import (
8
- NllbTokenizer,
9
- AutoModelForSeq2SeqLM,
10
- Seq2SeqTrainer,
11
- Seq2SeqTrainingArguments,
12
- DataCollatorForSeq2Seq,
13
- )
14
- from normalize_bm_input import normalize_bm_input
15
- from normalize_bm_output import normalize_bm_output
16
 
17
  # =====================
18
  # 1️⃣ Environment / Cache
@@ -34,7 +26,7 @@ print(f"Using device: {device}")
34
  # =====================
35
  # Charger le modèle et le tokenizer NLLB
36
  try:
37
- model_name = "Gaoussin/Bamalingua-2"
38
  tokenizer = NllbTokenizer.from_pretrained(model_name)
39
  # Move model to the selected device (CPU or GPU)
40
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
@@ -48,13 +40,20 @@ except Exception as e:
48
  # =====================
49
  app = FastAPI()
50
 
51
-
52
  # Input schema
53
  class TranslationRequest(BaseModel):
54
  text: str
55
  src_lang: str # e.g., "bam_Latn"
56
  tgt_lang: str # e.g., "fra_Latn"
57
 
 
 
 
 
 
 
 
 
58
 
59
  # =====================
60
  # 5️⃣ Translation function - Restored to user's original logic
@@ -62,36 +61,41 @@ class TranslationRequest(BaseModel):
62
  def translateTo(text, src, tgt):
63
  tokenizer.src_lang = src
64
  tokenizer.tgt_lang = tgt
65
- print({text, tokenizer.src_lang, tokenizer.tgt_lang})
66
-
67
  # Prepare input for the model
68
  # We explicitly move the inputs to the same device as the model
69
  inputs = tokenizer(text, return_tensors="pt").to(device)
70
-
71
  # Generate translation using the user's logic
72
  output = model.generate(**inputs, max_length=128)
73
-
74
  # Decode the output
75
  return tokenizer.decode(output[0], skip_special_tokens=True)
76
 
77
-
78
  # =====================
79
  # 6️⃣ API Endpoints - Applying the Response Model
80
  # =====================
81
- @app.post("/translate")
82
  def translate(request: TranslationRequest):
83
  try:
84
- # --- 2. Core Translation ---
85
- result = translateTo(request.text, request.src_lang, request.tgt_lang)
86
- # --- 4. Final Output ---
87
- translation_list = [result, model_name]
88
- ###
89
- return [translation_list]
 
 
90
  except Exception as e:
91
  print(f"An error occurred during translation: {e}")
92
- raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
93
-
 
 
 
 
94
 
95
  @app.get("/")
96
  def root():
97
- return {"message": "API is running 🚀"}
 
2
  import torch
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
 
5
  # Note: Keep the imports together for clarity
6
+ from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
7
+ from normalize_bm_words import normalize_text
 
 
 
 
 
 
 
8
 
9
  # =====================
10
  # 1️⃣ Environment / Cache
 
26
  # =====================
27
  # Charger le modèle et le tokenizer NLLB
28
  try:
29
+ model_name = "Gaoussin/bamalingua-4"
30
  tokenizer = NllbTokenizer.from_pretrained(model_name)
31
  # Move model to the selected device (CPU or GPU)
32
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
 
40
  # =====================
41
  app = FastAPI()
42
 
 
43
# Request body schema for the /translate endpoint.
class TranslationRequest(BaseModel):
    """Incoming translation request.

    Carries the source text plus NLLB-style language codes
    (e.g. src_lang="bam_Latn", tgt_lang="fra_Latn").
    """
    text: str
    src_lang: str  # e.g. "bam_Latn"
    tgt_lang: str  # e.g. "fra_Latn"


# Response schema: guarantees both fields are serialized in the
# response JSON returned by /translate.
class TranslationResponse(BaseModel):
    """Outgoing translation response.

    Pairs the translated text with the application version identifier.
    """
    translation: str
    appVersionId: str
57
 
58
  # =====================
59
  # 5️⃣ Translation function - Restored to user's original logic
 
61
def translateTo(text, src, tgt):
    """Translate *text* from *src* to *tgt* with the module-level NLLB
    tokenizer/model pair.

    Returns the decoded translation string, special tokens stripped.
    """
    # Configure the tokenizer's language pair before encoding.
    tokenizer.src_lang = src
    tokenizer.tgt_lang = tgt
    print(tokenizer.src_lang, tokenizer.tgt_lang)

    # Encode on the same device as the model, then generate.
    encoded = tokenizer(text, return_tensors="pt").to(device)
    generated = model.generate(**encoded, max_length=128)

    # Decode the first (and only) sequence in the batch.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
75
 
 
76
  # =====================
77
  # 6️⃣ API Endpoints - Applying the Response Model
78
  # =====================
79
@app.post("/translate", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    """Normalize, translate, and return the request text.

    Returns a dict matching TranslationResponse; raises HTTP 500 when
    normalization, generation, or decoding fails.
    """
    try:
        # Pre-normalize Bambara contractions before hitting the model.
        cleaned = normalize_text(request.text)
        translated = translateTo(cleaned, request.src_lang, request.tgt_lang)

        # Shape the payload to match the TranslationResponse schema.
        return {
            "translation": translated,
            "appVersionId": "App Version id = 2",
        }
    except Exception as e:
        print(f"An error occurred during translation: {e}")
        # Raising HTTPException bypasses the response model; FastAPI
        # returns its standard JSON error body instead.
        raise HTTPException(
            status_code=500,
            detail=f"Translation failed: {str(e)}",
        )
98
 
99
@app.get("/")
def root():
    """Health-check endpoint confirming the API is up."""
    return {"message": "API is running 🚀"}
normalize_bm_input.py DELETED
import re

# Contracted Bambara forms (keys) mapped to their expanded forms
# (values). Entries whose value contains a space are handled in a
# first regex pass; single-word values in a second whole-word pass.
# NOTE: insertion order matters — it breaks ties when keys of equal
# length are sorted longest-first below.
DE_CONTRACTIONS = {
    # Apostrophe contractions expanding to multi-word phrases
    "k'a": "ka a",
    "a b'a": "a be a",
    "n'be": "ne be",
    "n'b'a": "ne be a",
    "b'a": "be a",
    "k'o": "ko o",
    "b'i": "be i",
    "k'i": "ka i",
    "k'aw": "ka aw",

    # Spelling normalizations that remain single words
    "kɔkɔ": "kɔgɔ",
    "bɛ": "be",
}


def normalize_bm_input(text: str) -> str:
    """De-contract (expand) Bambara forms in *text*.

    Lowercases the input, expands multi-word contractions first
    (longest contracted form wins), then rewrites single-word
    variants, and finally capitalizes the first character.
    """
    # Lowercase up front for consistent matching.
    text = text.lower()

    # --- Pass 1: contractions whose expansion is a multi-word phrase ---
    # Longest contracted forms first so e.g. "a b'a" is consumed
    # before the shorter "b'a" embedded inside it.
    phrase_map = {k: v for k, v in DE_CONTRACTIONS.items() if " " in v}
    ordered = sorted(phrase_map.items(), key=lambda kv: len(kv[0]), reverse=True)
    for contracted, expansion in ordered:
        # Word boundaries keep "b'a" from matching inside "b'adi".
        text = re.sub(r"\b" + re.escape(contracted) + r"\b", expansion, text)

    # --- Pass 2: single-word rewrites (e.g. 'kɔkɔ' -> 'kɔgɔ') ---
    word_map = {k: v for k, v in DE_CONTRACTIONS.items() if " " not in v}

    def _swap(match):
        # Replace only words present in the map; leave the rest alone.
        word = match.group(0)
        return word_map.get(word, word)

    text = re.sub(r"\b\S+\b", _swap, text)

    # Capitalize only the first character for presentation.
    return text[:1].upper() + text[1:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
normalize_bm_output.py DELETED
import re

# Space-separated Bambara phrases (keys) mapped to their contracted
# forms (values). Applied longest-phrase-first; insertion order
# breaks ties between keys of equal length.
CONTRACTIONS = {
    "ka a": "k'a",
    "a be a": "a b'a",
    "be a": "b'a",
    "ko o": "k'o",
    "di i": "d'i",
    "be i": "b'i",
}


def normalize_bm_output(text: str) -> str:
    """Re-contract expanded Bambara phrases in *text*.

    Lowercases the input, contracts multi-word phrases (longest
    phrase wins), runs a single-word pass, and capitalizes the
    first character of the result.
    """
    # Lowercase up front as the matching baseline.
    text = text.lower()

    # --- Pass 1: multi-word phrases, longest first so that
    # "a be a" is contracted before the embedded "be a". ---
    phrase_map = {k: v for k, v in CONTRACTIONS.items() if " " in k}
    ordered = sorted(phrase_map.items(), key=lambda kv: len(kv[0]), reverse=True)
    for phrase, contracted in ordered:
        # re.escape guards any special characters in the phrase;
        # \b anchors keep matches to whole-word spans.
        pattern = r"\b" + re.escape(phrase) + r"\b"
        text = re.sub(pattern, contracted, text, flags=re.IGNORECASE)

    # --- Pass 2: single-word contractions (currently none defined) ---
    word_map = {k: v for k, v in CONTRACTIONS.items() if " " not in k}

    def _swap(match):
        # Look up each whole word; unknown words pass through unchanged.
        word = match.group(0)
        return word_map.get(word, word)

    text = re.sub(r"\b\w+\b", _swap, text)

    return text[:1].upper() + text[1:]