Arsive2 committed on
Commit
afd2cc6
·
1 Parent(s): 8720cc4

Ctranslate performance upgrade

Dockerfile CHANGED
@@ -1,43 +1,48 @@
1
- FROM python:3.10-bullseye
2
 
3
- WORKDIR /app
 
 
 
 
 
 
4
 
5
- # Install system dependencies
6
- RUN apt-get update && apt-get install -y \
7
  build-essential \
8
- libffi-dev \
9
  git \
 
 
 
 
 
 
 
 
 
 
 
 
10
  && rm -rf /var/lib/apt/lists/*
11
 
12
- # Set up directories with proper permissions
13
- RUN mkdir -p /app/.cache /app/nltk_data && \
14
- chmod 777 /app/.cache /app/nltk_data
15
-
16
- # Set environment variables for cache directories
17
- ENV TRANSFORMERS_CACHE=/app/.cache
18
- ENV HF_HOME=/app/.cache
19
- ENV NLTK_DATA=/app/nltk_data
20
 
21
- # Copy requirements file
22
  COPY requirements.txt .
 
 
 
23
 
24
- # Install Python dependencies
25
- RUN pip install --no-cache-dir -r requirements.txt
26
-
27
- # Pre-download NLTK data before copying application code
28
  RUN python -c "import nltk; nltk.download('punkt', download_dir='/app/nltk_data')"
29
 
30
- # Copy application code
31
- COPY . .
 
32
 
33
- # Expose the port for the API
34
- EXPOSE 7860
35
 
36
- # Set environment variables
37
- ENV PYTHONUNBUFFERED=1
38
- ENV OMP_NUM_THREADS=4
39
- ENV MKL_NUM_THREADS=4
40
- ENV TORCH_CPU_NUM_THREADS=4
41
 
42
- # Run the API server
43
- CMD ["uvicorn", "api_server:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "900"]
 
1
+ FROM python:3.10-slim
2
 
3
+ LABEL maintainer="Arsive <arsive.ai@gmail.com>"
4
+ LABEL description="Universal Translator API with CTranslate2 optimization"
5
+
6
+ ENV PYTHONDONTWRITEBYTECODE=1 \
7
+ PYTHONUNBUFFERED=1 \
8
+ CT2_MODEL_CACHE=/app/.cache/ct2_models \
9
+ NLTK_DATA=/app/nltk_data
10
 
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
 
12
  build-essential \
13
+ python3-dev \
14
  git \
15
+ curl \
16
+ wget \
17
+ unzip \
18
+ cmake \
19
+ pkg-config \
20
+ libpoppler-cpp-dev \
21
+ poppler-utils \
22
+ libsm6 \
23
+ libxext6 \
24
+ libxrender-dev \
25
+ libgl1 \
26
+ libglib2.0-0 \
27
  && rm -rf /var/lib/apt/lists/*
28
 
29
+ WORKDIR /app
30
+ RUN mkdir -p /app/app/models /app/uploads /app/.cache/ct2_models /app/nltk_data /app/translation_logs
 
 
 
 
 
 
31
 
 
32
  COPY requirements.txt .
33
+ RUN pip install --upgrade pip && \
34
+ pip install --no-cache-dir torch==2.0.1 && \
35
+ pip install --no-cache-dir -r requirements.txt
36
 
 
 
 
 
37
  RUN python -c "import nltk; nltk.download('punkt', download_dir='/app/nltk_data')"
38
 
39
+ COPY app/ /app/app/
40
+ COPY *.py /app/
41
+ COPY fix_permissions.sh /app/
42
 
43
+ RUN chmod +x /app/fix_permissions.sh && \
44
+ /app/fix_permissions.sh
45
 
46
+ EXPOSE 8000
 
 
 
 
47
 
48
+ CMD ["gunicorn", "-b", "0.0.0.0:8000", "--timeout", "300", "--workers", "1", "--threads", "4", "app:app"]
 
README.md CHANGED
@@ -9,7 +9,163 @@ license: mit
9
  short_description: Language translation space
10
  ---
11
 
12
- # Universal Translator API
13
 
14
  This is a Hugging Face Spaces deployment of the Universal Translator API service, which provides translation capabilities using the MADLAD-400 3B model.
15
 
 
9
  short_description: Language translation space
10
  ---
11
 
12
+ # Universal Translator with CTranslate2 Optimization
13
+
14
+ This project implements a high-performance language translation service optimized with CTranslate2, supporting 450+ languages including special handling for Dravidian languages (Tamil, Telugu, Kannada, Malayalam).
15
+
16
+ ## 🚀 Performance Improvements
17
+
18
+ CTranslate2 is a custom inference engine for Transformer models that provides significant speed and memory improvements:
19
+
20
+ - **5-10x faster translation** compared to the standard Transformers library
21
+ - **Reduced memory usage** through model quantization
22
+ - **Batch processing** for improved throughput
23
+ - **Hardware optimization** for both CPU and GPU
24
+
25
+ ## 🔧 Key Features
26
+
27
+ - Text, HTML, and document (PDF) translation
28
+ - Special handling for Dravidian languages with language-specific tags
29
+ - Optimized batch processing for improved performance
30
+ - Docker support for easy deployment
31
+ - GPU acceleration when available
32
+
33
+ ## 📋 Requirements
34
+
35
+ - Python 3.8+
36
+ - CTranslate2 3.20.0+
37
+ - Transformers 4.28.0+
38
+ - PyTorch 2.0.0+
39
+ - Flask 2.2.3+
40
+ - Other dependencies in requirements.txt
41
+
42
+ ## 💻 Installation
43
+
44
+ ### Using Docker
45
+
46
+ ```bash
47
+ # Build the Docker image
48
+ docker build -t universal-translator .
49
+
50
+ # Run the container
51
+ docker run -p 8000:8000 -v "$(pwd)/models":/app/.cache/ct2_models universal-translator
52
+ ```
53
+
54
+ ## 🔁 Converting Models
55
+
56
+ The translation service automatically converts models as needed, but you can pre-convert them using the provided utility:
57
+
58
+ ```bash
59
+ # Convert a specific model
60
+ python ct2_model_converter.py --src en --tgt es --quantization int8
61
+
62
+ # Convert all common language pairs
63
+ python ct2_model_converter.py --all
64
+
65
+ # List available language pairs and quantization options
66
+ python ct2_model_converter.py --list
67
+ ```
68
+
69
+ ### Quantization Options
70
+
71
+ - `int8`: 8-bit integer quantization (best for CPU)
72
+ - `float16`: 16-bit floating point (best for GPU)
73
+ - `int16`: 16-bit integer quantization
74
+ - `float8`: 8-bit floating point (experimental)
75
+ - `auto`: Automatic selection based on device
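The `auto` option above can be read as a simple device-based rule. A minimal sketch, assuming `auto` resolves to `int8` on CPU and `float16` on GPU as the list suggests; `pick_quantization` is a hypothetical helper and is not taken from `ct2_model_converter.py`:

```python
def pick_quantization(device: str, requested: str = "auto") -> str:
    """Resolve a requested quantization type to a concrete one.

    Hypothetical helper: an explicit request is returned unchanged,
    while "auto" follows the README guidance (int8 for CPU, float16 for GPU).
    """
    if requested != "auto":
        return requested
    return "float16" if device.startswith("cuda") else "int8"


if __name__ == "__main__":
    print(pick_quantization("cpu"))
    print(pick_quantization("cuda:0"))
    print(pick_quantization("cpu", requested="int16"))
```

An explicit choice always wins, so a user can still force `int16` on a CPU host.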
76
+
77
+ ## 📊 Benchmarking
78
+
79
+ You can benchmark the performance improvements using the provided script:
80
+
81
+ ```bash
82
+ # Run benchmarks for all language pairs
83
+ python benchmark.py
84
+
85
+ # Run benchmarks for specific language pairs
86
+ python benchmark.py --lang-pairs en-es en-fr en-dra
87
+
88
+ # Customize benchmark parameters
89
+ python benchmark.py --runs 10 --warm-up 3 --output custom_results.json
90
+ ```
91
+
92
+ ## 🌐 API Usage
93
+
94
+ ### Text Translation
95
+
96
+ ```python
97
+ import requests
98
+
99
+ data = {
100
+ 'text': 'Hello, how are you today?',
101
+ 'source_lang': 'English',
102
+ 'target_lang': 'Spanish'
103
+ }
104
+
105
+ response = requests.post('http://localhost:8000/translate', data=data)
106
+ print(response.json()['translated_text'])
107
+ ```
108
+
109
+ ### HTML Translation
110
+
111
+ ```python
112
+ import requests
113
+
114
+ data = {
115
+ 'html': '<div><p>Hello, world!</p></div>',
116
+ 'source_lang': 'English',
117
+ 'target_lang': 'French'
118
+ }
119
+
120
+ response = requests.post('http://localhost:8000/translate-html', data=data)
121
+ print(response.json()['translated_html'])
122
+ ```
123
+
124
+ ### Document Translation
125
+
126
+ ```python
127
+ import requests
128
+
129
+ files = {
130
+ 'file': open('document.pdf', 'rb')
131
+ }
132
+
133
+ data = {
134
+ 'source_lang': 'English',
135
+ 'target_lang': 'German',
136
+ 'use_ocr': 'false'
137
+ }
138
+
139
+ response = requests.post('http://localhost:8000/process-document', files=files, data=data)
140
+ print(response.json()['translated_text'])
141
+ ```
142
+
143
+ ## 🌍 Dravidian Language Support
144
+
145
+ For translating to Dravidian languages (Tamil, Telugu, Kannada, Malayalam), the system automatically handles the required special tokens:
146
+
147
+ ```python
148
+ # Tamil translation example
149
+ data = {
150
+ 'text': 'Hello, how are you?',
151
+ 'source_lang': 'English',
152
+ 'target_lang': 'Tamil'
153
+ }
154
+
155
+ response = requests.post('http://localhost:8000/translate', data=data)
156
+ ```
157
+
158
+ The backend adds the special token `>>tam<<` for Tamil, `>>tel<<` for Telugu, `>>kan<<` for Kannada, or `>>mal<<` for Malayalam as required by the Helsinki NLP models.
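The token handling described above might look like the following. This is a sketch only: the mapping mirrors the four tokens named in the README, while `add_language_token` and the space between token and text are assumptions rather than the backend's actual code:

```python
# Target-language tokens named in the README for the en-dra model.
DRAVIDIAN_TOKENS = {
    "Tamil": ">>tam<<",
    "Telugu": ">>tel<<",
    "Kannada": ">>kan<<",
    "Malayalam": ">>mal<<",
}


def add_language_token(text: str, target_lang: str) -> str:
    """Prefix the text with the target-language token when one is defined.

    Hypothetical helper: languages without a token are returned unchanged.
    """
    token = DRAVIDIAN_TOKENS.get(target_lang)
    return f"{token} {text}" if token else text


if __name__ == "__main__":
    print(add_language_token("Hello, how are you?", "Tamil"))
    print(add_language_token("Hello, how are you?", "Spanish"))
```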
159
+
160
+ ## 📝 License
161
+
162
+ This project is licensed under the MIT License - see the LICENSE file for details.
163
+
164
+ ## 🙏 Acknowledgements
165
+
166
+ - [Helsinki NLP](https://github.com/Helsinki-NLP) for providing the OPUS-MT models
167
+ - [OpenNMT](https://github.com/OpenNMT/CTranslate2) for the CTranslate2 optimization library
168
+ - [Hugging Face](https://huggingface.co/) for model hosting and Transformers library
169
 
170
  This is a Hugging Face Spaces deployment of the Universal Translator API service, which provides translation capabilities using the MADLAD-400 3B model.
171
 
api_server.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  import os
 
3
 
4
  import torch
5
  import uvicorn
@@ -10,7 +11,7 @@ from pydantic import BaseModel
10
  from app.models.document_processor import DocumentProcessor
11
  from app.models.html_processor import HTMLProcessor
12
  from app.models.text_chunker import TextChunker
13
- from app.models.translation_model import TranslationModel
14
 
15
  logging.basicConfig(
16
  level=logging.INFO,
@@ -20,8 +21,8 @@ logger = logging.getLogger(__name__)
20
 
21
  app = FastAPI(
22
  title="Universal Translator API",
23
- description="API for text, HTML, and document translation services",
24
- version="1.0.0"
25
  )
26
 
27
  app.add_middleware(
@@ -33,11 +34,16 @@ app.add_middleware(
33
  )
34
 
35
  try:
36
- model = TranslationModel()
 
 
37
  html_processor = HTMLProcessor()
38
  text_chunker = TextChunker(max_tokens=250, overlap_tokens=30)
39
  document_processor = DocumentProcessor()
40
 
 
 
 
41
  initialization_error = None
42
  except Exception as e:
43
  logger.error(f"Error initializing components: {str(e)}")
@@ -70,7 +76,11 @@ async def root():
70
  "message": "Service initialization failed",
71
  "error": initialization_error
72
  }
73
- return {"status": "ok", "model": "OPUS-MT/NLLB-CPU-Optimized", "version": "1.0"}
 
 
 
 
74
 
75
  @app.get("/health")
76
  async def health_check():
@@ -81,14 +91,14 @@ async def health_check():
81
  "environment": {
82
  "python_version": os.environ.get('PYTHON_VERSION'),
83
  "cuda_available": torch.cuda.is_available(),
84
- "device": str(model.device) if hasattr(model, 'device') else "Unknown",
85
- "loaded_models": list(model.opus_mt_models.keys()) if hasattr(model, 'opus_mt_models') else []
86
  }
87
  }
88
 
89
  @app.post("/translate", response_model=TranslationResponse)
90
  async def translate_text(request: TranslationRequest):
91
- """Translate text from source to target language"""
92
  if initialization_error:
93
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
94
 
@@ -98,20 +108,25 @@ async def translate_text(request: TranslationRequest):
98
  if request.special_token:
99
  logger.info(f"Using special language token: {request.special_token}")
100
 
101
- chunks = text_chunker.create_chunks(request.text)
102
- translated_chunks = []
103
-
104
- for chunk in chunks:
105
- translated_text = model.translate(
106
- chunk.text,
 
 
 
 
 
 
 
 
 
 
107
  request.source_lang_code,
108
  request.target_lang_code
109
  )
110
- translated_chunks.append(translated_text)
111
-
112
- final_translation = text_chunker.combine_translations(
113
- request.text, chunks, translated_chunks
114
- )
115
 
116
  return {"translated_text": final_translation}
117
  except Exception as e:
@@ -120,7 +135,7 @@ async def translate_text(request: TranslationRequest):
120
 
121
  @app.post("/translate-html", response_model=HTMLTranslationResponse)
122
  async def translate_html(request: HTMLTranslationRequest):
123
- """Translate HTML content while preserving structure"""
124
  if initialization_error:
125
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
126
 
@@ -128,9 +143,8 @@ async def translate_html(request: HTMLTranslationRequest):
128
  text_fragments, dom_data = html_processor.extract_text(request.html)
129
 
130
  if not text_fragments:
131
- return {"translated_html": request.html} # No text to translate
132
 
133
- # Apply special token to each text fragment if needed
134
  if request.special_token:
135
  logger.info(f"Using special language token for HTML: {request.special_token}")
136
  text_fragments = html_processor.prepare_fragments_with_token(
@@ -138,25 +152,31 @@ async def translate_html(request: HTMLTranslationRequest):
138
  request.special_token
139
  )
140
 
141
- translated_fragments = []
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- batch_size = 10
144
- for i in range(0, len(text_fragments), batch_size):
145
- batch = text_fragments[i:i+batch_size]
146
-
147
- for fragment in batch:
148
- if not fragment.strip():
149
- translated_fragments.append(fragment)
150
- continue
151
-
152
- translated_text = model.translate(
153
- fragment,
154
- request.source_lang_code,
155
- request.target_lang_code
156
- )
157
- translated_fragments.append(translated_text)
158
-
159
- translated_html = html_processor.replace_text(dom_data, translated_fragments)
160
 
161
  return {"translated_html": translated_html}
162
  except Exception as e:
@@ -171,7 +191,7 @@ async def process_document(
171
  special_token: str = Form(""),
172
  use_ocr: bool = Form(False)
173
  ):
174
- """Process and translate document (PDF or image)"""
175
  if initialization_error:
176
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
177
 
@@ -189,16 +209,30 @@ async def process_document(
189
  status_code=400,
190
  detail="No text could be extracted from the document"
191
  )
192
-
193
  if special_token:
194
  logger.info(f"Using special language token for document: {special_token}")
195
  extracted_text = f"{special_token}{extracted_text}"
196
 
197
- translated_text = model.translate(
198
- extracted_text,
199
- source_lang_code,
200
- target_lang_code
201
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  return {
204
  "extracted_text": extracted_text,
 
1
  import logging
2
  import os
3
+ import time
4
 
5
  import torch
6
  import uvicorn
 
11
  from app.models.document_processor import DocumentProcessor
12
  from app.models.html_processor import HTMLProcessor
13
  from app.models.text_chunker import TextChunker
14
+ from app.models.translation_model_ct2 import TranslationModelCT2
15
 
16
  logging.basicConfig(
17
  level=logging.INFO,
 
21
 
22
  app = FastAPI(
23
  title="Universal Translator API",
24
+ description="API for text, HTML, and document translation services with CTranslate2 optimization",
25
+ version="2.0.0"
26
  )
27
 
28
  app.add_middleware(
 
34
  )
35
 
36
  try:
37
+ start_time = time.time()
38
+
39
+ model = TranslationModelCT2(model_cache_dir=os.getenv("CT2_MODEL_CACHE", ".cache/ct2_models"))
40
  html_processor = HTMLProcessor()
41
  text_chunker = TextChunker(max_tokens=250, overlap_tokens=30)
42
  document_processor = DocumentProcessor()
43
 
44
+ initialization_time = time.time() - start_time
45
+ logger.info(f"Initialized components in {initialization_time:.2f}s")
46
+
47
  initialization_error = None
48
  except Exception as e:
49
  logger.error(f"Error initializing components: {str(e)}")
 
76
  "message": "Service initialization failed",
77
  "error": initialization_error
78
  }
79
+ return {
80
+ "status": "ok",
81
+ "model": "CTranslate2 Optimized with MADLAD-400 3B model",
82
+ "version": "2.0"
83
+ }
84
 
85
  @app.get("/health")
86
  async def health_check():
 
91
  "environment": {
92
  "python_version": os.environ.get('PYTHON_VERSION'),
93
  "cuda_available": torch.cuda.is_available(),
94
+ "device": model.device if hasattr(model, 'device') else "Unknown",
95
+ "model_info": model.get_model_info() if hasattr(model, 'get_model_info') else {}
96
  }
97
  }
98
 
99
  @app.post("/translate", response_model=TranslationResponse)
100
  async def translate_text(request: TranslationRequest):
101
+ """Translate text from source to target language using CTranslate2"""
102
  if initialization_error:
103
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
104
 
 
108
  if request.special_token:
109
  logger.info(f"Using special language token: {request.special_token}")
110
 
111
+ if len(request.text) > 1000:
112
+ chunks = text_chunker.create_chunks(request.text)
113
+ chunk_texts = [chunk.text for chunk in chunks]
114
+
115
+ translated_chunks = model.translate_batch(
116
+ chunk_texts,
117
+ request.source_lang_code,
118
+ request.target_lang_code
119
+ )
120
+
121
+ final_translation = text_chunker.combine_translations(
122
+ request.text, chunks, translated_chunks
123
+ )
124
+ else:
125
+ final_translation = model.translate(
126
+ request.text,
127
  request.source_lang_code,
128
  request.target_lang_code
129
  )
 
 
 
 
 
130
 
131
  return {"translated_text": final_translation}
132
  except Exception as e:
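The new `/translate` path only chunks input longer than 1000 characters and sends all chunks through one batch call. Note that the threshold counts characters, while `TextChunker` is configured in tokens. A standalone sketch of that dispatch, with hypothetical callables standing in for `model.translate`, `model.translate_batch`, and the chunker:

```python
from typing import Callable, List


def translate_long_text(
    text: str,
    translate_one: Callable[[str], str],
    translate_many: Callable[[List[str]], List[str]],
    split: Callable[[str], List[str]],
    join: Callable[[List[str]], str],
    threshold: int = 1000,
) -> str:
    """Translate short text directly; chunk and batch-translate long text."""
    if len(text) <= threshold:
        return translate_one(text)
    chunks = split(text)
    return join(translate_many(chunks))


if __name__ == "__main__":
    # Stand-in "translator" that upper-cases text, for illustration only.
    upper_many = lambda items: [item.upper() for item in items]
    fixed_split = lambda t: [t[i:i + 4] for i in range(0, len(t), 4)]
    print(translate_long_text("abcdefghij", str.upper, upper_many,
                              fixed_split, "".join, threshold=5))
```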
 
135
 
136
  @app.post("/translate-html", response_model=HTMLTranslationResponse)
137
  async def translate_html(request: HTMLTranslationRequest):
138
+ """Translate HTML content while preserving structure using CTranslate2"""
139
  if initialization_error:
140
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
141
 
 
143
  text_fragments, dom_data = html_processor.extract_text(request.html)
144
 
145
  if not text_fragments:
146
+ return {"translated_html": request.html}
147
 
 
148
  if request.special_token:
149
  logger.info(f"Using special language token for HTML: {request.special_token}")
150
  text_fragments = html_processor.prepare_fragments_with_token(
 
152
  request.special_token
153
  )
154
 
155
+ non_empty_fragments = []
156
+ empty_indices = []
157
+ for i, fragment in enumerate(text_fragments):
158
+ if fragment.strip():
159
+ non_empty_fragments.append(fragment)
160
+ else:
161
+ empty_indices.append(i)
162
+
163
+ translated_fragments = model.translate_batch(
164
+ non_empty_fragments,
165
+ request.source_lang_code,
166
+ request.target_lang_code
167
+ )
168
 
169
+ full_translated_fragments = []
170
+ non_empty_idx = 0
171
+
172
+ for i in range(len(text_fragments)):
173
+ if i in empty_indices:
174
+ full_translated_fragments.append(text_fragments[i])
175
+ else:
176
+ full_translated_fragments.append(translated_fragments[non_empty_idx])
177
+ non_empty_idx += 1
178
+
179
+ translated_html = html_processor.replace_text(dom_data, full_translated_fragments)
 
 
 
 
 
 
180
 
181
  return {"translated_html": translated_html}
182
  except Exception as e:
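The HTML path filters out blank fragments, translates the rest in one batch, and then restores the original order. A compact sketch of the same idea, assuming a `translate_batch` callable that returns one output per input; this variant leaves whitespace-only fragments unchanged, as the pre-commit loop did:

```python
from typing import Callable, List


def translate_fragments(
    fragments: List[str],
    translate_batch: Callable[[List[str]], List[str]],
) -> List[str]:
    """Batch-translate non-blank fragments, keeping every position intact."""
    # Indices of fragments that contain something other than whitespace.
    keep = [i for i, fragment in enumerate(fragments) if fragment.strip()]
    translated = translate_batch([fragments[i] for i in keep]) if keep else []

    result = list(fragments)  # blank fragments are carried over unchanged
    for index, text in zip(keep, translated):
        result[index] = text
    return result


if __name__ == "__main__":
    shout = lambda items: [item.upper() for item in items]
    print(translate_fragments(["Hi", "  ", "", "Bye"], shout))
```

Writing results back by index avoids the separate `empty_indices` lookup and the running `non_empty_idx` counter.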
 
191
  special_token: str = Form(""),
192
  use_ocr: bool = Form(False)
193
  ):
194
+ """Process and translate document (PDF or image) using CTranslate2"""
195
  if initialization_error:
196
  raise HTTPException(status_code=500, detail=f"Service not properly initialized: {initialization_error}")
197
 
 
209
  status_code=400,
210
  detail="No text could be extracted from the document"
211
  )
212
+
213
  if special_token:
214
  logger.info(f"Using special language token for document: {special_token}")
215
  extracted_text = f"{special_token}{extracted_text}"
216
 
217
+ if len(extracted_text) > 1000:
218
+ chunks = text_chunker.create_chunks(extracted_text)
219
+ chunk_texts = [chunk.text for chunk in chunks]
220
+
221
+ translated_chunks = model.translate_batch(
222
+ chunk_texts,
223
+ source_lang_code,
224
+ target_lang_code
225
+ )
226
+
227
+ translated_text = text_chunker.combine_translations(
228
+ extracted_text, chunks, translated_chunks
229
+ )
230
+ else:
231
+ translated_text = model.translate(
232
+ extracted_text,
233
+ source_lang_code,
234
+ target_lang_code
235
+ )
236
 
237
  return {
238
  "extracted_text": extracted_text,
app/models/benchmark_script.py ADDED
@@ -0,0 +1,364 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Benchmark script to compare performance between standard Transformers
4
+ and CTranslate2 optimized translation models.
5
+ """
6
+
7
+ import argparse
8
+ import json
9
+ import logging
10
+ import os
11
+ import sys
12
+ import time
13
+ from typing import Dict, List, Tuple
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MarianMTModel
19
+
20
+ # Add project root to path for imports
21
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
22
+
23
+ # Configure logging
24
+ logging.basicConfig(
25
+ level=logging.INFO,
26
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
27
+ )
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # Import our models
31
+ try:
32
+ from app.models.translation_model import TranslationModel # Standard model
33
+ from app.models.translation_model_ct2 import TranslationModelCT2 # CTranslate2 model
34
+ except ImportError:
35
+ logger.error("Could not import translation models. Make sure you're running this script from the project root.")
36
+ sys.exit(1)
37
+
38
+ # Test sentences for various languages
39
+ TEST_SENTENCES = {
40
+ "en-es": [
41
+ "Hello, how are you today?",
42
+ "I would like to book a flight to Madrid for next week.",
43
+ "The quick brown fox jumps over the lazy dog.",
44
+ "Artificial intelligence is transforming the way we live and work.",
45
+ "Please contact our customer service if you have any questions."
46
+ ],
47
+ "en-fr": [
48
+ "Hello, how are you today?",
49
+ "I would like to book a flight to Paris for next week.",
50
+ "The quick brown fox jumps over the lazy dog.",
51
+ "Artificial intelligence is transforming the way we live and work.",
52
+ "Please contact our customer service if you have any questions."
53
+ ],
54
+ "en-de": [
55
+ "Hello, how are you today?",
56
+ "I would like to book a flight to Berlin for next week.",
57
+ "The quick brown fox jumps over the lazy dog.",
58
+ "Artificial intelligence is transforming the way we live and work.",
59
+ "Please contact our customer service if you have any questions."
60
+ ],
61
+ "en-dra": [
62
+ "Hello, how are you today?",
63
+ "I would like to book a flight to Chennai for next week.",
64
+ "The quick brown fox jumps over the lazy dog.",
65
+ "Artificial intelligence is transforming the way we live and work.",
66
+ "Please contact our customer service if you have any questions."
67
+ ]
68
+ }
69
+
70
+ def benchmark_standard_model(
71
+ src_lang: str,
72
+ tgt_lang: str,
73
+ sentences: List[str],
74
+ num_runs: int = 5,
75
+ warm_up: int = 2
76
+ ) -> Dict:
77
+ """Benchmark the standard Transformers model."""
78
+ logger.info(f"Benchmarking standard Transformers model for {src_lang}-{tgt_lang}")
79
+
80
+ # Initialize model
81
+ model = TranslationModel()
82
+
83
+ # Warm-up runs
84
+ logger.info(f"Performing {warm_up} warm-up runs...")
85
+ for _ in range(warm_up):
86
+ for sentence in sentences[:2]: # Use only first 2 sentences for warm-up
87
+ model.translate(sentence, src_lang, tgt_lang)
88
+
89
+ # Actual benchmark
90
+ logger.info(f"Performing {num_runs} benchmark runs...")
91
+ times = []
92
+ translations = []
93
+
94
+ for run in range(num_runs):
95
+ run_times = []
96
+ run_translations = []
97
+
98
+ for sentence in tqdm.tqdm(sentences, desc=f"Run {run+1}/{num_runs}"):
99
+ start_time = time.perf_counter()
100
+ translation = model.translate(sentence, src_lang, tgt_lang)
101
+ elapsed_time = time.perf_counter() - start_time
102
+
103
+ run_times.append(elapsed_time)
104
+ run_translations.append(translation)
105
+
106
+ times.append(run_times)
107
+
108
+ # Only keep translations from the first run
109
+ if run == 0:
110
+ translations = run_translations
111
+
112
+ # Calculate statistics
113
+ all_times = np.array(times).flatten()
114
+ stats = {
115
+ "mean_time": float(np.mean(all_times)),
116
+ "median_time": float(np.median(all_times)),
117
+ "std_dev": float(np.std(all_times)),
118
+ "min_time": float(np.min(all_times)),
119
+ "max_time": float(np.max(all_times)),
120
+ "total_time": float(np.sum(all_times)),
121
+ "num_sentences": len(sentences) * num_runs,
122
+ "translations": translations
123
+ }
124
+
125
+ return stats
126
+
127
+ def benchmark_ct2_model(
128
+ src_lang: str,
129
+ tgt_lang: str,
130
+ sentences: List[str],
131
+ num_runs: int = 5,
132
+ warm_up: int = 2
133
+ ) -> Dict:
134
+ """Benchmark the CTranslate2 optimized model."""
135
+ logger.info(f"Benchmarking CTranslate2 model for {src_lang}-{tgt_lang}")
136
+
137
+ # Initialize model
138
+ model = TranslationModelCT2()
139
+
140
+ # Warm-up runs
141
+ logger.info(f"Performing {warm_up} warm-up runs...")
142
+ for _ in range(warm_up):
143
+ for sentence in sentences[:2]: # Use only first 2 sentences for warm-up
144
+ model.translate(sentence, src_lang, tgt_lang)
145
+
146
+ # Actual benchmark
147
+ logger.info(f"Performing {num_runs} benchmark runs...")
148
+ times = []
149
+ translations = []
150
+
151
+ for run in range(num_runs):
152
+ run_times = []
153
+ run_translations = []
154
+
155
+ for sentence in tqdm.tqdm(sentences, desc=f"Run {run+1}/{num_runs}"):
156
+ start_time = time.perf_counter()
157
+ translation = model.translate(sentence, src_lang, tgt_lang)
158
+ elapsed_time = time.perf_counter() - start_time
159
+
160
+ run_times.append(elapsed_time)
161
+ run_translations.append(translation)
162
+
163
+ times.append(run_times)
164
+
165
+ # Only keep translations from the first run
166
+ if run == 0:
167
+ translations = run_translations
168
+
169
+ # Calculate statistics
170
+ all_times = np.array(times).flatten()
171
+ stats = {
172
+ "mean_time": float(np.mean(all_times)),
173
+ "median_time": float(np.median(all_times)),
174
+ "std_dev": float(np.std(all_times)),
175
+ "min_time": float(np.min(all_times)),
176
+ "max_time": float(np.max(all_times)),
177
+ "total_time": float(np.sum(all_times)),
178
+ "num_sentences": len(sentences) * num_runs,
179
+ "translations": translations
180
+ }
181
+
182
+ return stats
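`benchmark_standard_model` and `benchmark_ct2_model` differ only in which model they construct. One possible refactor, shown as a sketch under that assumption, is a shared timing helper that takes any single-sentence translate callable. `time_translations` is a hypothetical name, and the standard-library `statistics` module stands in for NumPy here:

```python
import statistics
import time
from typing import Callable, Dict, List


def time_translations(
    translate: Callable[[str], str],
    sentences: List[str],
    num_runs: int = 5,
    warm_up: int = 2,
) -> Dict[str, float]:
    """Time per-sentence translation calls and summarise the results."""
    # Warm-up on the first two sentences, mirroring the benchmark script.
    for _ in range(warm_up):
        for sentence in sentences[:2]:
            translate(sentence)

    times: List[float] = []
    for _ in range(num_runs):
        for sentence in sentences:
            start = time.perf_counter()  # monotonic, high-resolution clock
            translate(sentence)
            times.append(time.perf_counter() - start)

    return {
        "mean_time": statistics.fmean(times),
        "median_time": statistics.median(times),
        "min_time": min(times),
        "max_time": max(times),
        "total_time": sum(times),
        "num_sentences": len(times),
    }


if __name__ == "__main__":
    print(time_translations(str.upper, ["hello", "world"], num_runs=3, warm_up=1))
```

Each benchmark could then pass a small lambda that binds its model and language pair.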
183
+
184
+ def benchmark_batch(
185
+ src_lang: str,
186
+ tgt_lang: str,
187
+ sentences: List[str],
188
+ num_runs: int = 5,
189
+ warm_up: int = 2
190
+ ) -> Dict:
191
+ """Benchmark batch translation with CTranslate2."""
192
+ logger.info(f"Benchmarking CTranslate2 batch translation for {src_lang}-{tgt_lang}")
193
+
194
+ # Initialize model
195
+ model = TranslationModelCT2()
196
+
197
+ # Warm-up runs
198
+ logger.info(f"Performing {warm_up} warm-up runs...")
199
+ for _ in range(warm_up):
200
+ model.translate_batch(sentences[:2], src_lang, tgt_lang)
201
+
202
+ # Actual benchmark
203
+ logger.info(f"Performing {num_runs} benchmark runs...")
204
+ times = []
205
+ translations = []
206
+
207
+ for run in range(num_runs):
208
+ start_time = time.perf_counter()
209
+ batch_translations = model.translate_batch(sentences, src_lang, tgt_lang)
210
+ elapsed_time = time.perf_counter() - start_time
211
+
212
+ times.append(elapsed_time)
213
+
214
+ # Only keep translations from the first run
215
+ if run == 0:
216
+ translations = batch_translations
217
+
218
+ # Calculate statistics
219
+ stats = {
220
+ "mean_time": float(np.mean(times)),
221
+ "median_time": float(np.median(times)),
222
+ "std_dev": float(np.std(times)),
223
+ "min_time": float(np.min(times)),
224
+ "max_time": float(np.max(times)),
225
+ "total_time": float(np.sum(times)),
226
+ "num_sentences": len(sentences),
227
+ "num_batches": num_runs,
228
+ "translations": translations
229
+ }
230
+
231
+ return stats
232
+
233
+ def run_benchmarks(
234
+ lang_pairs: List[Tuple[str, str]],
235
+ num_runs: int = 5,
236
+ warm_up: int = 2,
237
+ output_file: str = "benchmark_results.json"
238
+ ) -> Dict:
239
+ """Run benchmarks for specified language pairs."""
240
+ device = "cuda" if torch.cuda.is_available() else "cpu"
241
+ logger.info(f"Running benchmarks on {device}")
242
+
243
+ results = {
244
+ "device": device,
245
+ "cuda_available": torch.cuda.is_available(),
246
+ "cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
247
+ "num_runs": num_runs,
248
+ "warm_up_runs": warm_up,
249
+ "language_pairs": {}
250
+ }
251
+
252
+ for src_lang, tgt_lang in lang_pairs:
253
+ model_key = f"{src_lang}-{tgt_lang}"
254
+
255
+ if model_key not in TEST_SENTENCES:
256
+ logger.warning(f"No test sentences available for {model_key}, skipping...")
257
+ continue
258
+
259
+ logger.info(f"Benchmarking {model_key}...")
260
+
261
+ sentences = TEST_SENTENCES[model_key]
262
+
263
+ # Run standard model benchmark
264
+ standard_stats = benchmark_standard_model(
265
+ src_lang, tgt_lang, sentences, num_runs, warm_up
266
+ )
267
+
268
+ # Run CTranslate2 model benchmark
269
+ ct2_stats = benchmark_ct2_model(
270
+ src_lang, tgt_lang, sentences, num_runs, warm_up
271
+ )
272
+
273
+ # Run batch translation benchmark
274
+ batch_stats = benchmark_batch(
275
+ src_lang, tgt_lang, sentences, num_runs, warm_up
276
+ )
277
+
278
+ # Calculate speedup
279
+ speedup = standard_stats["mean_time"] / ct2_stats["mean_time"]
280
+ batch_speedup = standard_stats["mean_time"] * len(sentences) / batch_stats["mean_time"]
281
+
282
+ results["language_pairs"][model_key] = {
283
+ "standard_model": standard_stats,
284
+ "ct2_model": ct2_stats,
285
+ "batch_translation": batch_stats,
286
+ "speedup": float(speedup),
287
+ "batch_speedup": float(batch_speedup)
288
+ }
289
+
290
+ # Print summary
291
+ logger.info(f"\nResults for {model_key}:")
292
+ logger.info(f" Standard model average time: {standard_stats['mean_time']:.4f}s")
293
+ logger.info(f" CTranslate2 model average time: {ct2_stats['mean_time']:.4f}s")
294
+ logger.info(f" Batch translation average time: {batch_stats['mean_time']:.4f}s (for {len(sentences)} sentences)")
295
+ logger.info(f" Speedup: {speedup:.2f}x")
296
+ logger.info(f" Batch speedup: {batch_speedup:.2f}x")
297
+
298
+ # Save results to file
299
+ with open(output_file, "w") as f:
300
+ json.dump(results, f, indent=2)
301
+
302
+ logger.info(f"Benchmark results saved to {output_file}")
303
+
304
+ return results
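The two speedup figures compare different units: `mean_time` is per sentence for the standard and CTranslate2 runs, but per batch for `benchmark_batch`, so the batch figure scales the per-sentence baseline by the number of sentences. A small worked sketch with made-up timings (`compute_speedups` is a hypothetical helper):

```python
from typing import Tuple


def compute_speedups(
    standard_mean: float,
    ct2_mean: float,
    batch_mean: float,
    num_sentences: int,
) -> Tuple[float, float]:
    """Return (per-sentence speedup, batch speedup) as run_benchmarks does."""
    speedup = standard_mean / ct2_mean
    # Time the standard model needs for the whole batch, one sentence at a time.
    standard_batch_time = standard_mean * num_sentences
    batch_speedup = standard_batch_time / batch_mean
    return speedup, batch_speedup


if __name__ == "__main__":
    # Illustrative numbers only, not measured results.
    print(compute_speedups(0.5, 0.125, 0.25, 5))
```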
305
+
306
+ def main():
307
+ """Main entry point for the benchmark script."""
308
+ parser = argparse.ArgumentParser(
309
+ description="Benchmark translation models performance"
310
+ )
311
+
312
+ parser.add_argument(
313
+ "--lang-pairs",
314
+ type=str,
315
+ nargs="+",
316
+ default=["en-es", "en-fr", "en-de", "en-dra"],
317
+ help="Language pairs to benchmark (e.g., 'en-es en-fr')"
318
+ )
319
+ parser.add_argument(
320
+ "--runs",
321
+ type=int,
322
+ default=5,
323
+ help="Number of benchmark runs"
324
+ )
325
+ parser.add_argument(
326
+ "--warm-up",
327
+ type=int,
328
+ default=2,
329
+ help="Number of warm-up runs"
330
+ )
331
+ parser.add_argument(
332
+ "--output",
333
+ type=str,
334
+ default="benchmark_results.json",
335
+ help="Output file for benchmark results"
336
+ )
337
+
338
+ args = parser.parse_args()
339
+
340
+ # Parse language pairs
341
+ lang_pairs = []
342
+ for pair in args.lang_pairs:
343
+ if "-" in pair:
344
+ src, tgt = pair.split("-", 1)
345
+ lang_pairs.append((src, tgt))
346
+ else:
347
+ logger.warning(f"Invalid language pair format: {pair}, skipping...")
348
+
349
+ if not lang_pairs:
350
+ logger.error("No valid language pairs specified")
351
+ return 1
352
+
353
+ # Run benchmarks
354
+ run_benchmarks(
355
+ lang_pairs=lang_pairs,
356
+ num_runs=args.runs,
357
+ warm_up=args.warm_up,
358
+ output_file=args.output
359
+ )
360
+
361
+ return 0
362
+
363
+ if __name__ == "__main__":
364
+ sys.exit(main())
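The language-pair parsing in `main()` can be exercised on its own. The following is a minimal sketch rather than part of the commit: `parse_lang_pairs` is a hypothetical helper name, and splitting on the first hyphen only (so that a target code containing a hyphen stays intact) is an assumption of this sketch.

```python
from typing import List, Tuple


def parse_lang_pairs(raw_pairs: List[str]) -> Tuple[List[Tuple[str, str]], List[str]]:
    """Split 'src-tgt' strings into (src, tgt) tuples and collect malformed entries."""
    valid: List[Tuple[str, str]] = []
    skipped: List[str] = []
    for pair in raw_pairs:
        # partition() splits on the first hyphen only and never raises
        src, sep, tgt = pair.partition("-")
        if sep and src and tgt:
            valid.append((src, tgt))
        else:
            skipped.append(pair)
    return valid, skipped


if __name__ == "__main__":
    pairs, bad = parse_lang_pairs(["en-es", "en-dra", "enfr", "-de"])
    print("valid:", pairs)
    print("skipped:", bad)
```

Returning the skipped entries alongside the valid ones lets the caller decide whether to warn, as the script does, or to abort.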
app/models/ct2_model_converter.py ADDED
@@ -0,0 +1,282 @@
+ #!/usr/bin/env python
+ """
+ Utility script to convert Helsinki NLP Opus MT models to CTranslate2 format.
+ The common language pairs include the English-to-Dravidian (en-dra) model.
+ """
+
+ import argparse
+ import logging
+ import os
+ import sys
+ from typing import Dict
+
+ import torch
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Common language pairs
+ COMMON_LANGUAGE_PAIRS = [
+     ("en", "es"),  # English to Spanish
+     ("en", "fr"),  # English to French
+     ("en", "de"),  # English to German
+     ("en", "ru"),  # English to Russian
+     ("en", "zh"),  # English to Chinese
+     ("en", "ar"),  # English to Arabic
+     ("en", "hi"),  # English to Hindi
+     ("en", "dra"),  # English to Dravidian languages
+     ("es", "en"),  # Spanish to English
+     ("fr", "en"),  # French to English
+     ("de", "en"),  # German to English
+     ("ru", "en"),  # Russian to English
+     ("zh", "en"),  # Chinese to English
+     ("ar", "en"),  # Arabic to English
+     ("hi", "en"),  # Hindi to English
+ ]
+
+ # Supported quantization types.
+ # Names other than "auto" must match the types the CTranslate2 converter accepts.
+ QUANTIZATION_TYPES = {
+     "int8": "8-bit integer quantization (best for CPU)",
+     "int16": "16-bit integer quantization",
+     "float16": "16-bit floating point (best for modern GPUs)",
+     "bfloat16": "16-bit brain floating point",
+     "auto": "Automatic selection based on device",
+ }
+
+ def get_device() -> str:
+     """Get the best available device for model inference."""
+     if torch.cuda.is_available():
+         return "cuda"
+     else:
+         return "cpu"
+
+ def get_auto_quantization(device: str) -> str:
+     """Get the appropriate quantization based on device."""
+     if device == "cuda":
+         return "float16"
+     else:
+         return "int8"
+
+ def get_huggingface_model_name(src_lang: str, tgt_lang: str) -> str:
+     """Get the appropriate HuggingFace model name for the language pair."""
+     return f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
+
+ def convert_model(
+     src_lang: str,
+     tgt_lang: str,
+     output_dir: str,
+     quantization: str = "auto",
+     force: bool = False
+ ) -> bool:
+     """
+     Convert a Helsinki NLP model to CTranslate2 format.
+
+     Args:
+         src_lang: Source language code
+         tgt_lang: Target language code
+         output_dir: Output directory path
+         quantization: Quantization type
+         force: Whether to force conversion if model exists
+
+     Returns:
+         bool: Success status
+     """
+     try:
+         # Determine output path
+         model_key = f"{src_lang}-{tgt_lang}"
+         model_dir = os.path.join(output_dir, f"ct2_{model_key}")
+
+         # Check if model already exists
+         if os.path.isdir(model_dir) and not force:
+             logger.info(f"Model {model_key} already exists at {model_dir}. Use --force to overwrite.")
+             return True
+
+         # Get the HuggingFace model name
+         huggingface_model = get_huggingface_model_name(src_lang, tgt_lang)
+         logger.info(f"Converting model {huggingface_model} to CTranslate2 format")
+
+         # Handle auto quantization
+         device = get_device()
+         if quantization == "auto":
+             quantization = get_auto_quantization(device)
+
+         logger.info(f"Using {quantization} quantization for {device} device")
+
+         try:
+             # Import here to avoid dependency if not installed
+             from ctranslate2.converters import TransformersConverter
+
+             # Create converter
+             converter = TransformersConverter(huggingface_model)
+
+             # Convert model (force=True is safe here: the existence check ran above)
+             converter.convert(
+                 model_dir,
+                 quantization=quantization,
+                 force=True
+             )
+
+             logger.info(f"Successfully converted {huggingface_model} to CTranslate2 format at {model_dir}")
+             return True
+
+         except ImportError:
+             logger.warning("Could not import TransformersConverter, falling back to command line")
+
+             # Fallback to command line
+             import subprocess
+             cmd = [
+                 "ct2-transformers-converter",
+                 "--model", huggingface_model,
+                 "--output_dir", model_dir,
+                 "--quantization", quantization,
+                 "--force"
+             ]
+
+             # Run the command
+             logger.info(f"Running command: {' '.join(cmd)}")
+             result = subprocess.run(cmd, capture_output=True, text=True)
+
+             if result.returncode == 0:
+                 logger.info("Successfully converted model using shell command")
+                 return True
+             else:
+                 logger.error(f"Error in shell command: {result.stderr}")
+                 return False
+
+     except Exception as e:
+         logger.error(f"Error converting model {src_lang}-{tgt_lang}: {str(e)}")
+         return False
+
+ def convert_all_models(
+     output_dir: str,
+     quantization: str = "auto",
+     force: bool = False
+ ) -> Dict[str, bool]:
+     """
+     Convert all common language pair models to CTranslate2 format.
+
+     Args:
+         output_dir: Output directory path
+         quantization: Quantization type
+         force: Whether to force conversion if model exists
+
+     Returns:
+         Dict[str, bool]: Results by language pair
+     """
+     results = {}
+
+     for src_lang, tgt_lang in COMMON_LANGUAGE_PAIRS:
+         model_key = f"{src_lang}-{tgt_lang}"
+         logger.info(f"Processing model pair: {model_key}")
+
+         success = convert_model(
+             src_lang=src_lang,
+             tgt_lang=tgt_lang,
+             output_dir=output_dir,
+             quantization=quantization,
+             force=force
+         )
+
+         results[model_key] = success
+
+     # Print summary
+     logger.info("\n=== Conversion Summary ===")
+     success_count = sum(1 for success in results.values() if success)
+     logger.info(f"Successfully converted {success_count} of {len(results)} models")
+
+     for model_key, success in results.items():
+         status = "✓" if success else "✗"
+         logger.info(f"{status} {model_key}")
+
+     return results
+
+ def main():
+     """Main entry point for the script."""
+     parser = argparse.ArgumentParser(
+         description="Convert Helsinki NLP Opus MT models to CTranslate2 format"
+     )
+
+     parser.add_argument(
+         "--src",
+         type=str,
+         help="Source language code (e.g., 'en')"
+     )
+     parser.add_argument(
+         "--tgt",
+         type=str,
+         help="Target language code (e.g., 'es', 'fr', 'dra')"
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default=".cache/ct2_models",
+         help="Output directory for converted models"
+     )
+     parser.add_argument(
+         "--quantization",
+         type=str,
+         choices=list(QUANTIZATION_TYPES.keys()),
+         default="auto",
+         help="Quantization type to use"
+     )
+     parser.add_argument(
+         "--force",
+         action="store_true",
+         help="Force conversion even if model exists"
+     )
+     parser.add_argument(
+         "--all",
+         action="store_true",
+         help="Convert all common language pairs"
+     )
+     parser.add_argument(
+         "--list",
+         action="store_true",
+         help="List all common language pairs"
+     )
+
+     args = parser.parse_args()
+
+     # List common language pairs if requested (no output directory needed)
+     if args.list:
+         print("\nCommon language pairs:")
+         for src, tgt in COMMON_LANGUAGE_PAIRS:
+             print(f" {src}-{tgt}")
+         print("\nQuantization types:")
+         for q_type, desc in QUANTIZATION_TYPES.items():
+             print(f" {q_type}: {desc}")
+         return 0
+
+     # Make sure output directory exists
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     # Convert all models if requested
+     if args.all:
+         results = convert_all_models(
+             output_dir=args.output_dir,
+             quantization=args.quantization,
+             force=args.force
+         )
+         return 0 if all(results.values()) else 1
+
+     # Otherwise, need source and target languages
+     if not args.src or not args.tgt:
+         parser.error("--src and --tgt are required unless --all or --list is specified")
+
+     # Convert single model
+     success = convert_model(
+         src_lang=args.src,
+         tgt_lang=args.tgt,
+         output_dir=args.output_dir,
+         quantization=args.quantization,
+         force=args.force
+     )
+
+     return 0 if success else 1
+
+ if __name__ == "__main__":
+     sys.exit(main())
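The converter's two decisions, which quantization to use and which command line to fall back to, can be checked without downloading any model. The sketch below is illustrative only: `resolve_quantization`, `build_converter_command`, and `KNOWN_QUANTIZATIONS` are hypothetical names, and the set of accepted quantization names is an assumption that should be checked against the installed CTranslate2 version. The command is only built, never executed.

```python
import os
from typing import List

# Assumed subset of quantization names; verify against your CTranslate2 version
KNOWN_QUANTIZATIONS = {"int8", "int16", "float16"}


def resolve_quantization(requested: str, device: str) -> str:
    """Map 'auto' to a device-specific default, otherwise validate the name."""
    if requested == "auto":
        return "float16" if device == "cuda" else "int8"
    if requested not in KNOWN_QUANTIZATIONS:
        raise ValueError(f"Unsupported quantization: {requested}")
    return requested


def build_converter_command(src: str, tgt: str, output_dir: str, quantization: str) -> List[str]:
    """Build the ct2-transformers-converter argument list without running it."""
    model_dir = os.path.join(output_dir, f"ct2_{src}-{tgt}")
    return [
        "ct2-transformers-converter",
        "--model", f"Helsinki-NLP/opus-mt-{src}-{tgt}",
        "--output_dir", model_dir,
        "--quantization", quantization,
        "--force",
    ]


if __name__ == "__main__":
    quant = resolve_quantization("auto", "cpu")
    print(" ".join(build_converter_command("en", "es", ".cache/ct2_models", quant)))
```

Keeping the command construction in a pure function makes the fallback path testable in environments where the converter executable is not installed.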
app/models/translation_model_ct2.py ADDED
@@ -0,0 +1,277 @@
+ import logging
+ import os
+ import re
+ from typing import Dict, List, Tuple
+
+ import ctranslate2
+ import torch
+ import transformers
+ from transformers import AutoTokenizer
+
+ logger = logging.getLogger(__name__)
+
+ class TranslationModelCT2:
+     """
+     Optimized translation model using CTranslate2 for faster inference.
+     """
+
+     def __init__(self, model_cache_dir: str = ".cache/ct2_models"):
+         """
+         Initialize the CTranslate2 translation model manager.
+
+         Args:
+             model_cache_dir: Directory to cache converted models
+         """
+         self.model_cache_dir = model_cache_dir
+         self.device = self._get_device()
+         self.ct2_models = {}  # Cache for loaded CTranslate2 models
+         self.tokenizers = {}  # Cache for tokenizers
+         self.model_paths = {}  # Map for model paths
+         self.initialized = False
+         self.initialization_error = None
+
+         # Create cache directory
+         os.makedirs(model_cache_dir, exist_ok=True)
+
+         try:
+             # Log available device
+             logger.info(f"TranslationModelCT2 initialized with device: {self.device}")
+             self.initialized = True
+         except Exception as e:
+             self.initialization_error = str(e)
+             logger.error(f"Failed to initialize CTranslate2 translation model: {str(e)}")
+
+     def _get_device(self) -> str:
+         """Get the best available device for model inference."""
+         if torch.cuda.is_available():
+             logger.info("Using CUDA GPU for CTranslate2")
+             return "cuda"
+         else:
+             logger.info("Using CPU for CTranslate2")
+             return "cpu"
+
+     def _get_compute_type(self) -> str:
+         """Get the appropriate compute type based on device."""
+         if self.device == "cuda":
+             return "int8_float16"  # More efficient for GPU
+         else:
+             return "int8"  # More efficient for CPU
+
+     def _get_model_key(self, source_lang_code: str, target_lang_code: str) -> str:
+         """Create a unique key for the model cache."""
+         return f"{source_lang_code}-{target_lang_code}"
+
+     def _get_huggingface_model_name(self, source_lang_code: str, target_lang_code: str) -> str:
+         """Get the appropriate HuggingFace model name for the language pair."""
+         # Handle special case for Dravidian languages
+         if target_lang_code == "dra":
+             return "Helsinki-NLP/opus-mt-en-dra"
+
+         # Standard language pairs
+         return f"Helsinki-NLP/opus-mt-{source_lang_code}-{target_lang_code}"
+
+     def _get_ct2_model_path(self, source_lang_code: str, target_lang_code: str) -> str:
+         """Get the path for the CTranslate2 model."""
+         model_key = self._get_model_key(source_lang_code, target_lang_code)
+         return os.path.join(self.model_cache_dir, f"ct2_{model_key}")
+
+     def _convert_model_if_needed(self, source_lang_code: str, target_lang_code: str) -> str:
+         """Convert the model to CTranslate2 format if not already converted."""
+         model_key = self._get_model_key(source_lang_code, target_lang_code)
+         model_path = self._get_ct2_model_path(source_lang_code, target_lang_code)
+
+         # Check if model already exists
+         if os.path.isdir(model_path):
+             logger.info(f"Using existing CTranslate2 model for {model_key}")
+             return model_path
+
+         # Get the Hugging Face model name
+         huggingface_model = self._get_huggingface_model_name(source_lang_code, target_lang_code)
+         logger.info(f"Converting model {huggingface_model} to CTranslate2 format")
+
+         # Weights are stored as the part before "_", which is "int8" for both
+         # "int8" (CPU) and "int8_float16" (GPU)
+         quantization = self._get_compute_type().split("_")[0]
+
+         try:
+             # Import here to avoid dependency if ct2-transformers-converter not used
+             from ctranslate2.converters import TransformersConverter
+
+             # Create converter
+             converter = TransformersConverter(huggingface_model)
+
+             # Convert model
+             converter.convert(
+                 model_path,
+                 quantization=quantization,
+                 force=True
+             )
+
+             logger.info(f"Successfully converted {huggingface_model} to CTranslate2 format at {model_path}")
+             return model_path
+         except Exception as e:
+             logger.error(f"Error converting model to CTranslate2 format: {str(e)}")
+
+             # Fallback - use shell command to convert
+             try:
+                 logger.info("Attempting conversion using ct2-transformers-converter shell command")
+
+                 import subprocess
+                 cmd = [
+                     "ct2-transformers-converter",
+                     "--model", huggingface_model,
+                     "--output_dir", model_path,
+                     "--quantization", quantization,
+                     "--force"
+                 ]
+
+                 # Run the command
+                 result = subprocess.run(cmd, capture_output=True, text=True)
+
+                 if result.returncode == 0:
+                     logger.info("Successfully converted model using shell command")
+                     return model_path
+                 else:
+                     logger.error(f"Error in shell command: {result.stderr}")
+                     raise ValueError(f"Failed to convert model: {result.stderr}")
+
+             except Exception as shell_error:
+                 logger.error(f"Error with shell conversion: {str(shell_error)}")
+                 raise ValueError(
+                     f"Could not convert model {huggingface_model} to CTranslate2 format"
+                 ) from shell_error
+
+     def _load_model(self, source_lang_code: str, target_lang_code: str) -> Tuple[ctranslate2.Translator, transformers.PreTrainedTokenizer]:
+         """Load a CTranslate2 model and tokenizer for the language pair."""
+         model_key = self._get_model_key(source_lang_code, target_lang_code)
+
+         # Check if already loaded
+         if model_key in self.ct2_models and model_key in self.tokenizers:
+             return self.ct2_models[model_key], self.tokenizers[model_key]
+
+         try:
+             # Convert model if needed
+             model_path = self._convert_model_if_needed(source_lang_code, target_lang_code)
+
+             # Load the tokenizer
+             huggingface_model = self._get_huggingface_model_name(source_lang_code, target_lang_code)
+             tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
+
+             # Load CTranslate2 model
+             inter_threads = 1  # Number of parallel translations
+             intra_threads = min(os.cpu_count() or 4, 4)  # Number of threads per translation
+
+             translator = ctranslate2.Translator(
+                 model_path,
+                 device=self.device,
+                 compute_type=self._get_compute_type(),
+                 inter_threads=inter_threads,
+                 intra_threads=intra_threads
+             )
+
+             # Cache the model and tokenizer
+             self.ct2_models[model_key] = translator
+             self.tokenizers[model_key] = tokenizer
+             self.model_paths[model_key] = model_path
+
+             logger.info(f"Successfully loaded CTranslate2 model and tokenizer for {model_key}")
+             return translator, tokenizer
+
+         except Exception as e:
+             logger.error(f"Error loading CTranslate2 model: {str(e)}")
+             raise
+
+     def translate(self, text: str, source_lang_code: str, target_lang_code: str) -> str:
+         """
+         Translate text from source language to target language using CTranslate2.
+
+         Args:
+             text: Text to translate
+             source_lang_code: Source language code
+             target_lang_code: Target language code
+
+         Returns:
+             Translated text
+         """
+         if not text.strip():
+             return ""
+
+         try:
+             if not self.initialized:
+                 raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")
+
+             # Handle special tokens in text (for Dravidian languages)
+             # We don't need to modify the target_lang_code since the special token is already in the text
+
+             # Load the model and tokenizer
+             translator, tokenizer = self._load_model(source_lang_code, target_lang_code)
+
+             # Tokenize the input text
+             tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
+
+             # Translate using CTranslate2
+             results = translator.translate_batch([tokens])
+
+             # The first result's first hypothesis
+             target_tokens = results[0].hypotheses[0]
+
+             # Convert tokens back to text
+             translated_text = tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens))
+
+             # Clean up the output
+             return re.sub(r'\s+', ' ', translated_text).strip()
+
+         except Exception as e:
+             logger.error(f"CTranslate2 translation error: {str(e)}")
+             raise
+
+     def translate_batch(self, texts: List[str], source_lang_code: str, target_lang_code: str) -> List[str]:
+         """
+         Translate a batch of texts for improved performance.
+
+         Args:
+             texts: List of texts to translate
+             source_lang_code: Source language code
+             target_lang_code: Target language code
+
+         Returns:
+             List of translated texts
+         """
+         if not texts:
+             return []
+
+         try:
+             if not self.initialized:
+                 raise ValueError(f"Translation model not properly initialized: {self.initialization_error}")
+
+             # Load the model and tokenizer
+             translator, tokenizer = self._load_model(source_lang_code, target_lang_code)
+
+             # Tokenize all input texts
+             tokens_batch = [
+                 tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
+                 for text in texts
+             ]
+
+             # Translate the batch
+             results = translator.translate_batch(tokens_batch)
+
+             # Extract the translations
+             translated_texts = []
+             for result in results:
+                 if result.hypotheses:
+                     target_tokens = result.hypotheses[0]
+                     translated_text = tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens))
+                     translated_text = re.sub(r'\s+', ' ', translated_text).strip()
+                     translated_texts.append(translated_text)
+                 else:
+                     translated_texts.append("")
+
+             return translated_texts
+
+         except Exception as e:
+             logger.error(f"CTranslate2 batch translation error: {str(e)}")
+             raise
+
+     def get_model_info(self) -> Dict:
+         """Get information about loaded models."""
+         return {
+             "device": self.device,
+             "compute_type": self._get_compute_type(),
+             "loaded_models": list(self.ct2_models.keys()),
+             "model_paths": self.model_paths
+         }
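`translate()` returns an empty string for blank input, while `translate_batch()` forwards every text to the model. One possible way to give batches the same behaviour is to filter blanks first and restore the original positions afterwards. This is a sketch under that assumption, not something the commit does: `translate_skipping_blanks` and `translate_fn` are hypothetical names, and `translate_fn` stands in for any callable that maps a list of texts to a list of translations of the same length.

```python
from typing import Callable, List


def translate_skipping_blanks(
    texts: List[str],
    translate_fn: Callable[[List[str]], List[str]],
) -> List[str]:
    """Translate only non-blank texts and return results aligned with the input."""
    # Remember where each non-blank text sits in the original list
    positions = [i for i, text in enumerate(texts) if text.strip()]
    if not positions:
        return [""] * len(texts)

    translated = translate_fn([texts[i] for i in positions])
    if len(translated) != len(positions):
        raise ValueError("translate_fn must return one output per input")

    # Blank inputs keep an empty string; the rest get their translation
    output = [""] * len(texts)
    for i, result in zip(positions, translated):
        output[i] = result
    return output


if __name__ == "__main__":
    # Stand-in for a real batch translator, used only to show the alignment
    def fake_translate(batch: List[str]) -> List[str]:
        return [text.upper() for text in batch]

    print(translate_skipping_blanks(["hello", "   ", "world"], fake_translate))
```

With the class above, `translate_fn` could be something like `lambda batch: model.translate_batch(batch, "en", "es")`, where `model` is a `TranslationModelCT2` instance.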
requirements.txt CHANGED
@@ -12,4 +12,6 @@ tqdm
  beautifulsoup4
  PyMuPDF
  protobuf
- torch
+ torch
+ ctranslate2
+ hf-hub-ctranslate2