Arsive2 committed on
Commit
eb52047
·
1 Parent(s): fb3dfc3

Updated to use a smaller model

Browse files
Files changed (4) hide show
  1. Dockerfile +5 -1
  2. api_server.py +38 -38
  3. app/models/translation_model.py +156 -86
  4. requirements.txt +3 -1
Dockerfile CHANGED
@@ -1,4 +1,5 @@
1
  FROM python:3.10-bullseye
 
2
  WORKDIR /app
3
 
4
  # Install system dependencies
@@ -34,6 +35,9 @@ EXPOSE 7860
34
 
35
  # Set environment variables
36
  ENV PYTHONUNBUFFERED=1
 
 
 
37
 
38
  # Run the API server
39
- CMD ["uvicorn", "api_server:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
  FROM python:3.10-bullseye
2
+
3
  WORKDIR /app
4
 
5
  # Install system dependencies
 
35
 
36
  # Set environment variables
37
  ENV PYTHONUNBUFFERED=1
38
+ ENV OMP_NUM_THREADS=4
39
+ ENV MKL_NUM_THREADS=4
40
+ ENV TORCH_CPU_NUM_THREADS=4
41
 
42
  # Run the API server
43
+ CMD ["uvicorn", "api_server:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "900"]
api_server.py CHANGED
@@ -6,6 +6,10 @@ import torch
6
  import os
7
  import logging
8
  import uvicorn
 
 
 
 
9
 
10
  # Configure logging
11
  logging.basicConfig(
@@ -30,24 +34,13 @@ app.add_middleware(
30
  allow_headers=["*"],
31
  )
32
 
33
- # Set environment variables if not already set
34
- os.environ.setdefault('TRANSFORMERS_CACHE', '/app/.cache')
35
- os.environ.setdefault('HF_HOME', '/app/.cache')
36
- os.environ.setdefault('NLTK_DATA', '/app/nltk_data')
37
-
38
- # Create necessary directories with proper permissions
39
- os.makedirs(os.environ.get('TRANSFORMERS_CACHE'), exist_ok=True)
40
- os.makedirs(os.environ.get('NLTK_DATA'), exist_ok=True)
41
-
42
  try:
43
- from app.models.text_chunker import TextChunker
44
- from app.models.html_processor import HTMLProcessor
45
- from app.models.translation_model import TranslationModel
46
-
47
- # Initialize components
48
- text_chunker = TextChunker(max_tokens=250, overlap_tokens=30)
49
- html_processor = HTMLProcessor()
50
  model = TranslationModel()
 
 
 
51
 
52
  initialization_error = None
53
  except Exception as e:
@@ -80,7 +73,7 @@ async def root():
80
  "message": "Service initialization failed",
81
  "error": initialization_error
82
  }
83
- return {"status": "ok", "model": "MADLAD-400", "version": "3B"}
84
 
85
  @app.get("/health")
86
  async def health_check():
@@ -89,12 +82,10 @@ async def health_check():
89
  "status": "ok" if not initialization_error else "error",
90
  "error": initialization_error,
91
  "environment": {
92
- "transformers_cache": os.environ.get('TRANSFORMERS_CACHE'),
93
- "hf_home": os.environ.get('HF_HOME'),
94
- "nltk_data": os.environ.get('NLTK_DATA'),
95
  "python_version": os.environ.get('PYTHON_VERSION'),
96
- "cuda_available": torch.cuda.is_available() if 'torch' in globals() else "Unknown",
97
- "device": str(model.device) if 'model' in globals() and hasattr(model, 'device') else "Unknown"
 
98
  }
99
  }
100
 
@@ -105,7 +96,10 @@ async def translate_text(request: TranslationRequest):
105
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
106
 
107
  try:
108
- # Get chunks using TextChunker
 
 
 
109
  chunks = text_chunker.create_chunks(request.text)
110
  translated_chunks = []
111
 
@@ -143,17 +137,23 @@ async def translate_html(request: HTMLTranslationRequest):
143
 
144
  # Process each text fragment individually
145
  translated_fragments = []
146
- for fragment in text_fragments:
147
- if not fragment.strip():
148
- translated_fragments.append(fragment)
149
- continue
150
-
151
- translated_text = model.translate(
152
- fragment,
153
- request.source_lang_code,
154
- request.target_lang_code
155
- )
156
- translated_fragments.append(translated_text)
 
 
 
 
 
 
157
 
158
  # Replace the original text with translated text in the HTML structure
159
  translated_html = html_processor.replace_text(dom_data, translated_fragments)
@@ -179,9 +179,9 @@ async def process_document(
179
  file_content = await file.read()
180
 
181
  # Process document to extract text
182
- extracted_text = model.process_document(
183
- file_content,
184
- file.filename,
185
  use_ocr=use_ocr
186
  )
187
 
@@ -191,7 +191,7 @@ async def process_document(
191
  detail="No text could be extracted from the document"
192
  )
193
 
194
- # Translate the extracted text
195
  translated_text = model.translate(
196
  extracted_text,
197
  source_lang_code,
 
6
  import os
7
  import logging
8
  import uvicorn
9
+ from app.models.translation_model import TranslationModel
10
+ from app.models.html_processor import HTMLProcessor
11
+ from app.models.text_chunker import TextChunker
12
+ from app.models.document_processor import DocumentProcessor
13
 
14
  # Configure logging
15
  logging.basicConfig(
 
34
  allow_headers=["*"],
35
  )
36
 
37
+ # Initialize model components
 
 
 
 
 
 
 
 
38
  try:
39
+ # Use the CPU-optimized translation model
 
 
 
 
 
 
40
  model = TranslationModel()
41
+ html_processor = HTMLProcessor()
42
+ text_chunker = TextChunker(max_tokens=250, overlap_tokens=30)
43
+ document_processor = DocumentProcessor()
44
 
45
  initialization_error = None
46
  except Exception as e:
 
73
  "message": "Service initialization failed",
74
  "error": initialization_error
75
  }
76
+ return {"status": "ok", "model": "OPUS-MT/NLLB-CPU-Optimized", "version": "1.0"}
77
 
78
  @app.get("/health")
79
  async def health_check():
 
82
  "status": "ok" if not initialization_error else "error",
83
  "error": initialization_error,
84
  "environment": {
 
 
 
85
  "python_version": os.environ.get('PYTHON_VERSION'),
86
+ "cuda_available": torch.cuda.is_available(),
87
+ "device": str(model.device) if hasattr(model, 'device') else "Unknown",
88
+ "loaded_models": list(model.opus_mt_models.keys()) if hasattr(model, 'opus_mt_models') else []
89
  }
90
  }
91
 
 
96
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
97
 
98
  try:
99
+ # Using the OPUS-MT/NLLB hybrid model for more efficient translation
100
+ logger.info(f"Translating from {request.source_lang_code} to {request.target_lang_code}")
101
+
102
+ # Create chunks using TextChunker for long texts
103
  chunks = text_chunker.create_chunks(request.text)
104
  translated_chunks = []
105
 
 
137
 
138
  # Process each text fragment individually
139
  translated_fragments = []
140
+
141
+ # Process in smaller batches to avoid timeouts
142
+ batch_size = 10
143
+ for i in range(0, len(text_fragments), batch_size):
144
+ batch = text_fragments[i:i+batch_size]
145
+
146
+ for fragment in batch:
147
+ if not fragment.strip():
148
+ translated_fragments.append(fragment)
149
+ continue
150
+
151
+ translated_text = model.translate(
152
+ fragment,
153
+ request.source_lang_code,
154
+ request.target_lang_code
155
+ )
156
+ translated_fragments.append(translated_text)
157
 
158
  # Replace the original text with translated text in the HTML structure
159
  translated_html = html_processor.replace_text(dom_data, translated_fragments)
 
179
  file_content = await file.read()
180
 
181
  # Process document to extract text
182
+ extracted_text = document_processor.process_document(
183
+ file_data=file_content,
184
+ filename=file.filename,
185
  use_ocr=use_ocr
186
  )
187
 
 
191
  detail="No text could be extracted from the document"
192
  )
193
 
194
+ # Translate the extracted text using our more efficient model
195
  translated_text = model.translate(
196
  extracted_text,
197
  source_lang_code,
app/models/translation_model.py CHANGED
@@ -3,35 +3,37 @@ import logging
3
  import re
4
  import os
5
  from typing import Optional, Dict, Any, List
6
- from transformers import T5ForConditionalGeneration, T5Tokenizer
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
  class TranslationModel:
11
  """
12
- Model class for handling the translation functionality using MADLAD-400 model
13
  """
14
 
15
- def __init__(self, model_name: str = "google/madlad400-3b-mt"):
16
  """
17
- Initialize the translation model.
18
 
19
  Args:
20
- model_name: Name of the Hugging Face model to use
21
  """
22
- self.model_name = model_name
23
- self.model = None
24
- self.tokenizer = None
25
  self.device = self._get_device()
 
 
 
26
  self.initialized = False
27
  self.initialization_error = None
28
 
29
- # Ensure cache directory exists and is writable
30
- cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/app/.cache')
31
- os.makedirs(cache_dir, exist_ok=True)
32
 
33
  try:
34
- self._load_model()
 
35
  self.initialized = True
36
  except Exception as e:
37
  self.initialization_error = str(e)
@@ -42,43 +44,91 @@ class TranslationModel:
42
  if torch.cuda.is_available():
43
  logger.info("Using CUDA GPU for translation")
44
  return torch.device("cuda")
45
- elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
46
- logger.info("Using Apple MPS (Metal) for translation")
47
- return torch.device("mps")
48
  else:
49
  logger.info("Using CPU for translation")
50
  return torch.device("cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- def _load_model(self):
53
- """Load the MADLAD-400 3B translation model."""
54
  try:
55
- logger.info(f"Loading translation model: {self.model_name}")
56
- self.tokenizer = T5Tokenizer.from_pretrained(
57
- self.model_name,
58
- cache_dir=os.environ.get('TRANSFORMERS_CACHE', '/app/.cache')
 
 
 
 
 
59
  )
 
60
 
61
- # Use torch_dtype=torch.bfloat16 if available for faster inference
62
- if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
63
- logger.info("Using bfloat16 precision for model loading")
64
- self.model = T5ForConditionalGeneration.from_pretrained(
65
- self.model_name,
66
- torch_dtype=torch.bfloat16,
67
- cache_dir=os.environ.get('TRANSFORMERS_CACHE', '/app/.cache')
68
- )
69
- else:
70
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
71
- logger.info(f"Using {dtype} precision for model loading")
72
- self.model = T5ForConditionalGeneration.from_pretrained(
73
- self.model_name,
74
- torch_dtype=dtype,
75
- cache_dir=os.environ.get('TRANSFORMERS_CACHE', '/app/.cache')
76
- )
77
 
78
- self.model.to(self.device)
79
- logger.info(f"Model loaded successfully on {self.device}")
80
  except Exception as e:
81
- logger.error(f"Error loading model: {str(e)}")
82
  raise
83
 
84
  def translate(self, text: str, source_lang_code: str, target_lang_code: str) -> str:
@@ -97,57 +147,77 @@ class TranslationModel:
97
  if not self.initialized:
98
  raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")
99
 
100
- # Prepare input with MADLAD-400 format: <2{target_lang}> {source_text}
101
- input_text = f"<2{target_lang_code}> {text}"
102
 
103
- inputs = self.tokenizer(
104
- input_text,
105
- return_tensors="pt",
106
- padding=True,
107
- truncation=True,
108
- max_length=512
109
- )
110
-
111
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
112
-
113
- with torch.no_grad():
114
- translated = self.model.generate(
115
- **inputs,
116
- max_length=512,
117
- num_beams=5,
118
- early_stopping=True
119
- )
120
-
121
- translated_text = self.tokenizer.batch_decode(
122
- translated,
123
- skip_special_tokens=True
124
- )[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
 
126
  return re.sub(r'\s+', ' ', translated_text).strip()
127
 
128
  except Exception as e:
129
  logger.error(f"Translation error: {str(e)}")
130
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- def process_document(self, file_data: bytes, filename: str, use_ocr: bool = False) -> str:
133
- """
134
- Process document to extract text using PyMuPDF and optional OCR.
135
-
136
- Args:
137
- file_data: Raw file content
138
- filename: Original filename
139
- use_ocr: Whether to use OCR for text extraction
140
-
141
- Returns:
142
- Extracted text as string
143
- """
144
- if not self.initialized:
145
- raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")
146
-
147
- from app.models.document_processor import DocumentProcessor
148
-
149
- # Initialize document processor
150
- doc_processor = DocumentProcessor()
151
-
152
- # Process document and extract text
153
- return doc_processor.process_document(file_data, filename, use_ocr)
 
3
  import re
4
  import os
5
  from typing import Optional, Dict, Any, List
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ from tqdm import tqdm
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
  class TranslationModel:
12
  """
13
+ More efficient translation model that uses smaller models optimized for CPU
14
  """
15
 
16
+ def __init__(self, model_cache_dir: str = ".cache/models"):
17
  """
18
+ Initialize the translation model manager.
19
 
20
  Args:
21
+ model_cache_dir: Directory to cache downloaded models
22
  """
23
+ self.model_cache_dir = model_cache_dir
 
 
24
  self.device = self._get_device()
25
+ self.opus_mt_models = {} # Cache for loaded OPUS-MT models
26
+ self.fallback_model = None
27
+ self.fallback_tokenizer = None
28
  self.initialized = False
29
  self.initialization_error = None
30
 
31
+ # Create cache directory
32
+ os.makedirs(model_cache_dir, exist_ok=True)
 
33
 
34
  try:
35
+ # Nothing is loaded here; models are loaded on demand at translation time
36
+ logger.info("TranslationModel initialized - models will be loaded on demand")
37
  self.initialized = True
38
  except Exception as e:
39
  self.initialization_error = str(e)
 
44
  if torch.cuda.is_available():
45
  logger.info("Using CUDA GPU for translation")
46
  return torch.device("cuda")
 
 
 
47
  else:
48
  logger.info("Using CPU for translation")
49
  return torch.device("cpu")
50
+
51
+ def _get_opus_mt_model_name(self, source_lang_code: str, target_lang_code: str) -> Optional[str]:
52
+ """Get the appropriate OPUS-MT model name for the language pair."""
53
+ # OPUS-MT uses different language codes in some cases
54
+ lang_code_mapping = {
55
+ 'zh': 'zho',
56
+ 'en': 'eng',
57
+ 'ar': 'ara',
58
+ 'fr': 'fra',
59
+ 'de': 'deu',
60
+ 'ru': 'rus',
61
+ 'pt': 'por',
62
+ 'es': 'spa',
63
+ 'it': 'ita',
64
+ 'nl': 'nld',
65
+ 'pl': 'pol',
66
+ 'ja': 'jpn',
67
+ 'ko': 'kor',
68
+ }
69
+
70
+ source = lang_code_mapping.get(source_lang_code, source_lang_code)
71
+ target = lang_code_mapping.get(target_lang_code, target_lang_code)
72
+
73
+ # Try direct model first
74
+ model_name = f"Helsinki-NLP/opus-mt-{source}-{target}"
75
+ return model_name
76
+
77
+ def _load_opus_mt_model(self, source_lang_code: str, target_lang_code: str):
78
+ """Load an OPUS-MT model for the specific language pair."""
79
+ model_name = self._get_opus_mt_model_name(source_lang_code, target_lang_code)
80
+
81
+ # Check if model already loaded
82
+ key = f"{source_lang_code}-{target_lang_code}"
83
+ if key in self.opus_mt_models:
84
+ return self.opus_mt_models[key]
85
+
86
+ try:
87
+ logger.info(f"Loading OPUS-MT model: {model_name}")
88
+
89
+ # Use half precision (float16) only when CUDA is available; on CPU keep float32
90
+ model = AutoModelForSeq2SeqLM.from_pretrained(
91
+ model_name,
92
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
93
+ cache_dir=self.model_cache_dir,
94
+ low_cpu_mem_usage=True
95
+ )
96
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)
97
+
98
+ model.to(self.device)
99
+ logger.info(f"OPUS-MT model loaded successfully: {model_name}")
100
+
101
+ # Cache the model
102
+ self.opus_mt_models[key] = (model, tokenizer)
103
+ return model, tokenizer
104
+
105
+ except Exception as e:
106
+ logger.warning(f"Could not load OPUS-MT model {model_name}: {str(e)}")
107
+ return None
108
+
109
+ def _load_fallback_model(self):
110
+ """Load the fallback NLLB-200 model for language pairs without OPUS-MT models."""
111
+ if self.fallback_model is not None:
112
+ return
113
 
 
 
114
  try:
115
+ # Use the small distilled version for efficiency on CPU
116
+ model_name = "facebook/nllb-200-distilled-600M"
117
+ logger.info(f"Loading fallback model: {model_name}")
118
+
119
+ self.fallback_model = AutoModelForSeq2SeqLM.from_pretrained(
120
+ model_name,
121
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
122
+ cache_dir=self.model_cache_dir,
123
+ low_cpu_mem_usage=True
124
  )
125
+ self.fallback_tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=self.model_cache_dir)
126
 
127
+ self.fallback_model.to(self.device)
128
+ logger.info(f"Fallback model loaded successfully: {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
 
 
130
  except Exception as e:
131
+ logger.error(f"Error loading fallback model: {str(e)}")
132
  raise
133
 
134
  def translate(self, text: str, source_lang_code: str, target_lang_code: str) -> str:
 
147
  if not self.initialized:
148
  raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")
149
 
150
+ # Try to use OPUS-MT model first (faster and often better quality)
151
+ opus_mt_result = self._load_opus_mt_model(source_lang_code, target_lang_code)
152
 
153
+ if opus_mt_result:
154
+ model, tokenizer = opus_mt_result
155
+
156
+ inputs = tokenizer(text, return_tensors="pt", padding=True)
157
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
158
+
159
+ with torch.no_grad():
160
+ outputs = model.generate(**inputs, max_length=512, num_beams=4, early_stopping=True)
161
+
162
+ translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
163
+ logger.info(f"Translation completed using OPUS-MT model")
164
+
165
+ else:
166
+ # Fall back to NLLB model
167
+ logger.info(f"No OPUS-MT model available for {source_lang_code}-{target_lang_code}, using fallback model")
168
+ self._load_fallback_model()
169
+
170
+ # NLLB uses a specific format for inputs
171
+ tokenizer = self.fallback_tokenizer
172
+ model = self.fallback_model
173
+
174
+ # Prepare input with NLLB format
175
+ inputs = tokenizer(text, return_tensors="pt", padding=True)
176
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
177
+
178
+ # NLLB language codes are like "eng_Latn", "fra_Latn", etc.
179
+ nllb_source = _get_nllb_code(source_lang_code)
180
+ nllb_target = _get_nllb_code(target_lang_code)
181
+
182
+ # Force decoder to start with target language token
183
+ forced_bos_token_id = tokenizer.lang_code_to_id[nllb_target]
184
+
185
+ with torch.no_grad():
186
+ outputs = model.generate(
187
+ **inputs,
188
+ forced_bos_token_id=forced_bos_token_id,
189
+ max_length=512,
190
+ num_beams=4,
191
+ early_stopping=True
192
+ )
193
+
194
+ translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
195
+ logger.info(f"Translation completed using fallback NLLB model")
196
 
197
+ # Clean up the output
198
  return re.sub(r'\s+', ' ', translated_text).strip()
199
 
200
  except Exception as e:
201
  logger.error(f"Translation error: {str(e)}")
202
  raise
203
+
204
+ def _get_nllb_code(lang_code: str) -> str:
205
+ """Convert ISO language code to NLLB language code format."""
206
+ # Mapping for common languages
207
+ nllb_mapping = {
208
+ 'en': 'eng_Latn',
209
+ 'fr': 'fra_Latn',
210
+ 'es': 'spa_Latn',
211
+ 'de': 'deu_Latn',
212
+ 'it': 'ita_Latn',
213
+ 'pt': 'por_Latn',
214
+ 'nl': 'nld_Latn',
215
+ 'ru': 'rus_Cyrl',
216
+ 'zh': 'zho_Hans',
217
+ 'ar': 'ara_Arab',
218
+ 'hi': 'hin_Deva',
219
+ 'ja': 'jpn_Jpan',
220
+ 'ko': 'kor_Hang',
221
+ }
222
 
223
+ return nllb_mapping.get(lang_code, f"{lang_code}_Latn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -4,10 +4,12 @@ pydantic==1.10.7
4
  transformers==4.30.2
5
  sentencepiece==0.1.99
6
  accelerate==0.20.3
 
7
  python-multipart==0.0.6
8
  pillow==9.5.0
9
  nltk==3.8.1
10
  tqdm==4.65.0
11
  beautifulsoup4==4.12.2
12
  PyMuPDF==1.22.5
13
- protobuf==3.20.3
 
 
4
  transformers==4.30.2
5
  sentencepiece==0.1.99
6
  accelerate==0.20.3
7
+ optimum==1.8.8
8
  python-multipart==0.0.6
9
  pillow==9.5.0
10
  nltk==3.8.1
11
  tqdm==4.65.0
12
  beautifulsoup4==4.12.2
13
  PyMuPDF==1.22.5
14
+ protobuf==3.20.3
15
+ torch==2.0.1