froster02 commited on
Commit
6919c1d
·
1 Parent(s): f7aacd7

feat: optimize for HF Spaces deployment

Browse files
Files changed (4) hide show
  1. Dockerfile +14 -6
  2. backend/app.py +14 -2
  3. backend/download_models.py +14 -0
  4. backend/models.py +59 -17
Dockerfile CHANGED
@@ -28,19 +28,27 @@ RUN pip install --no-cache-dir -r backend/requirements.txt
28
  # Copy all source files
29
  COPY backend/ ./backend/
30
 
31
- # Pre-download and bake models inside the Docker image during the build stage
32
- # This caches them permanently in the image so they are ready instantly on startup
33
- RUN python backend/download_models.py backend/models
 
 
 
 
 
 
 
 
34
 
35
  # Copy the built React assets from Stage 1 into the backend's static folder
36
  COPY --from=frontend-builder /app/frontend/dist ./frontend/dist
37
 
38
- # Create a non-root user with UID 1000 (standard for Hugging Face Spaces) and set directory ownerships
39
- RUN useradd -m -u 1000 user && \
40
- chown -R 1000:1000 /app
41
 
42
  # Switch to the non-root user
43
  USER user
 
44
  ENV PATH="/home/user/.local/bin:$PATH"
45
 
46
  # Expose port 7860
 
28
  # Copy all source files
29
  COPY backend/ ./backend/
30
 
31
+ # Create a non-root user with UID 1000 (standard for Hugging Face Spaces)
32
+ RUN useradd -m -u 1000 user
33
+
34
+ # Set environment variables for model caching in a writable location
35
+ ENV HF_HOME=/app/backend/models/hf_cache
36
+ ENV EASYOCR_MODULE_PATH=/app/backend/models/easyocr
37
+
38
+ # Pre-download and bake models inside the Docker image
39
+ # We run this as root but ensure the directory exists and will be chowned
40
+ RUN mkdir -p /app/backend/models/easyocr && \
41
+ python backend/download_models.py backend/models
42
 
43
  # Copy the built React assets from Stage 1 into the backend's static folder
44
  COPY --from=frontend-builder /app/frontend/dist ./frontend/dist
45
 
46
+ # Set directory ownerships for the non-root user
47
+ RUN chown -R 1000:1000 /app
 
48
 
49
  # Switch to the non-root user
50
  USER user
51
+ ENV HOME=/home/user
52
  ENV PATH="/home/user/.local/bin:$PATH"
53
 
54
  # Expose port 7860
backend/app.py CHANGED
@@ -1,11 +1,21 @@
1
  import os
2
  import shutil
3
  import uuid
 
4
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Depends
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.responses import FileResponse, StreamingResponse
7
  from fastapi.staticfiles import StaticFiles
8
  from pydantic import BaseModel
 
 
 
 
 
 
 
 
 
9
  try:
10
  from langdetect import detect, DetectorFactory
11
  DetectorFactory.seed = 0
@@ -36,10 +46,12 @@ def clean_temp_folder():
36
  @asynccontextmanager
37
  async def lifespan(app: FastAPI):
38
  # Startup logic
 
39
  clean_temp_folder()
40
- print("[*] Temporary folder cleared.")
41
  yield
42
  # Shutdown logic (if any)
 
43
 
44
  app = FastAPI(title="Offline Translation API", version="1.0.0", lifespan=lifespan)
45
 
@@ -97,7 +109,7 @@ class TTSRequest(BaseModel):
97
 
98
  @app.get("/health")
99
  def health_check():
100
- print("[*] Health check hit")
101
  return {"status": "healthy"}
102
 
103
  @app.get("/ping")
 
1
  import os
2
  import shutil
3
  import uuid
4
+ import logging
5
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Depends
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.responses import FileResponse, StreamingResponse
8
  from fastapi.staticfiles import StaticFiles
9
  from pydantic import BaseModel
10
+
11
+ # Configure logging
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
15
+ handlers=[logging.StreamHandler()]
16
+ )
17
+ logger = logging.getLogger("baif-api")
18
+
19
  try:
20
  from langdetect import detect, DetectorFactory
21
  DetectorFactory.seed = 0
 
46
  @asynccontextmanager
47
  async def lifespan(app: FastAPI):
48
  # Startup logic
49
+ logger.info("Application starting up...")
50
  clean_temp_folder()
51
+ logger.info("Temporary folder cleared.")
52
  yield
53
  # Shutdown logic (if any)
54
+ logger.info("Application shutting down...")
55
 
56
  app = FastAPI(title="Offline Translation API", version="1.0.0", lifespan=lifespan)
57
 
 
109
 
110
  @app.get("/health")
111
  def health_check():
112
+ logger.info("Health check hit")
113
  return {"status": "healthy"}
114
 
115
  @app.get("/ping")
backend/download_models.py CHANGED
@@ -54,6 +54,20 @@ def download_models(target_dir="./models"):
54
  except Exception as e:
55
  print(f"[✗] Error downloading TTS {model_id}: {e}", file=sys.stderr)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  print("\n[✓] All models downloaded successfully and cached for offline use!")
58
 
59
  if __name__ == "__main__":
 
54
  except Exception as e:
55
  print(f"[✗] Error downloading TTS {model_id}: {e}", file=sys.stderr)
56
 
57
+ # 4. EasyOCR Models (Marathi, Hindi, English)
58
+ print("\n[+] Downloading EasyOCR Models...")
59
+ try:
60
+ import easyocr
61
+ # Set EASYOCR_MODULE_PATH to make sure it downloads to the right place if the env var is not yet picked up
62
+ if "EASYOCR_MODULE_PATH" not in os.environ:
63
+ os.environ["EASYOCR_MODULE_PATH"] = os.path.join(target_dir, "easyocr")
64
+
65
+ # This will trigger the download of models for the specified languages
66
+ reader = easyocr.Reader(['hi', 'mr', 'en'], gpu=False)
67
+ print("[✓] Successfully downloaded EasyOCR models")
68
+ except Exception as e:
69
+ print(f"[✗] Error downloading EasyOCR models: {e}", file=sys.stderr)
70
+
71
  print("\n[✓] All models downloaded successfully and cached for offline use!")
72
 
73
  if __name__ == "__main__":
backend/models.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import numpy as np
4
  import soundfile as sf
5
  import threading
 
6
  from transformers import (
7
  pipeline,
8
  AutoModelForSeq2SeqLM,
@@ -12,6 +13,10 @@ from transformers import (
12
  WhisperForConditionalGeneration
13
  )
14
 
 
 
 
 
15
  class ModelManager:
16
  def __init__(self, cache_dir="./models"):
17
  self.cache_dir = os.path.abspath(cache_dir)
@@ -30,6 +35,7 @@ class ModelManager:
30
  self.device = "cpu"
31
 
32
  print(f"[*] ModelManager initialized using device: {self.device} (CI_MODE={self.ci_mode})")
 
33
 
34
  # Lazy load containers
35
  self.whisper_pipe = {}
@@ -38,25 +44,45 @@ class ModelManager:
38
  self.tts_models = {}
39
  self.tts_tokenizers = {}
40
 
 
 
 
 
 
 
 
 
41
  def get_whisper(self, size="base"):
42
  with self.lock:
43
  if size not in self.whisper_pipe:
44
  model_id = f"openai/whisper-{size}"
45
  print(f"[*] Loading STT model {model_id} from {self.cache_dir} on {self.device}...")
46
 
47
- # Load processor & model from local cache
48
- processor = WhisperProcessor.from_pretrained(model_id, cache_dir=self.cache_dir)
49
- model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=self.cache_dir)
50
-
51
- # Pipeline does chunking automatically for long files
52
- self.whisper_pipe[size] = pipeline(
53
- "automatic-speech-recognition",
54
- model=model,
55
- tokenizer=processor.tokenizer,
56
- feature_extractor=processor.feature_extractor,
57
- chunk_length_s=30,
58
- device=0 if self.device == "cuda" else (-1 if self.device == "cpu" else "mps")
59
- )
 
 
 
 
 
 
 
 
 
 
 
 
60
  return self.whisper_pipe[size]
61
 
62
  def get_nllb(self):
@@ -64,8 +90,14 @@ class ModelManager:
64
  if self.nllb_model is None:
65
  model_id = "facebook/nllb-200-distilled-600M"
66
  print(f"[*] Loading NLLB-200 translation model from {self.cache_dir} on {self.device}...")
67
- self.nllb_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir)
68
- self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir=self.cache_dir).to(self.device)
 
 
 
 
 
 
69
  return self.nllb_model, self.nllb_tokenizer
70
 
71
  def get_tts(self, lang):
@@ -81,8 +113,14 @@ class ModelManager:
81
  raise ValueError(f"Unsupported TTS language: {lang}")
82
 
83
  print(f"[*] Loading TTS model for {lang} ({model_id}) on {self.device}...")
84
- self.tts_tokenizers[lang] = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir)
85
- self.tts_models[lang] = VitsModel.from_pretrained(model_id, cache_dir=self.cache_dir).to(self.device)
 
 
 
 
 
 
86
 
87
  return self.tts_models[lang], self.tts_tokenizers[lang]
88
 
@@ -122,6 +160,7 @@ class ModelManager:
122
  stride_length_s=5,
123
  generate_kwargs=gen_kwargs
124
  )
 
125
 
126
  # Extract segments from chunks
127
  chunks = result.get("chunks", [])
@@ -193,6 +232,7 @@ class ModelManager:
193
  )
194
 
195
  translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
 
196
  return translated_text
197
 
198
  def translate_batch(self, texts, src_lang, tgt_lang):
@@ -252,6 +292,7 @@ class ModelManager:
252
  )
253
 
254
  translated_texts = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
 
255
 
256
  # Map back to full results list
257
  for i, idx in enumerate(non_empty_indices):
@@ -302,5 +343,6 @@ class ModelManager:
302
 
303
  # MMS-TTS models output sample rate is 16000Hz
304
  sf.write(output_path, waveform_numpy, samplerate=16000)
 
305
  print(f"[✓] TTS audio written to: {output_path}")
306
  return output_path
 
3
  import numpy as np
4
  import soundfile as sf
5
  import threading
6
+ import gc
7
  from transformers import (
8
  pipeline,
9
  AutoModelForSeq2SeqLM,
 
13
  WhisperForConditionalGeneration
14
  )
15
 
16
+ # Optimize Torch for CPU-only environments like HF Spaces
17
+ if not torch.cuda.is_available():
18
+ torch.set_num_threads(int(os.cpu_count() or 1))
19
+
20
  class ModelManager:
21
  def __init__(self, cache_dir="./models"):
22
  self.cache_dir = os.path.abspath(cache_dir)
 
35
  self.device = "cpu"
36
 
37
  print(f"[*] ModelManager initialized using device: {self.device} (CI_MODE={self.ci_mode})")
38
+ print(f"[*] Cache directory: {self.cache_dir}")
39
 
40
  # Lazy load containers
41
  self.whisper_pipe = {}
 
44
  self.tts_models = {}
45
  self.tts_tokenizers = {}
46
 
47
+ def _clear_memory(self):
48
+ """Force garbage collection and clear torch cache if on GPU"""
49
+ gc.collect()
50
+ if torch.cuda.is_available():
51
+ torch.cuda.empty_cache()
52
+ elif self.device == "mps":
53
+ torch.mps.empty_cache()
54
+
55
  def get_whisper(self, size="base"):
56
  with self.lock:
57
  if size not in self.whisper_pipe:
58
  model_id = f"openai/whisper-{size}"
59
  print(f"[*] Loading STT model {model_id} from {self.cache_dir} on {self.device}...")
60
 
61
+ try:
62
+ # Load processor & model from local cache
63
+ processor = WhisperProcessor.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True)
64
+ model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True)
65
+
66
+ # Pipeline does chunking automatically for long files
67
+ self.whisper_pipe[size] = pipeline(
68
+ "automatic-speech-recognition",
69
+ model=model,
70
+ tokenizer=processor.tokenizer,
71
+ feature_extractor=processor.feature_extractor,
72
+ chunk_length_s=30,
73
+ device=0 if self.device == "cuda" else (-1 if self.device == "cpu" else "mps")
74
+ )
75
+ print(f"[✓] Whisper-{size} loaded successfully.")
76
+ except Exception as e:
77
+ print(f"[!] Error loading Whisper-{size}: {e}")
78
+ # Try without local_files_only as fallback
79
+ self.whisper_pipe[size] = pipeline(
80
+ "automatic-speech-recognition",
81
+ model=model_id,
82
+ cache_dir=self.cache_dir,
83
+ chunk_length_s=30,
84
+ device=0 if self.device == "cuda" else (-1 if self.device == "cpu" else "mps")
85
+ )
86
  return self.whisper_pipe[size]
87
 
88
  def get_nllb(self):
 
90
  if self.nllb_model is None:
91
  model_id = "facebook/nllb-200-distilled-600M"
92
  print(f"[*] Loading NLLB-200 translation model from {self.cache_dir} on {self.device}...")
93
+ try:
94
+ self.nllb_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True)
95
+ self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True).to(self.device)
96
+ print("[✓] NLLB-200 loaded successfully.")
97
+ except Exception as e:
98
+ print(f"[!] Error loading NLLB-200: {e}")
99
+ self.nllb_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir)
100
+ self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir=self.cache_dir).to(self.device)
101
  return self.nllb_model, self.nllb_tokenizer
102
 
103
  def get_tts(self, lang):
 
113
  raise ValueError(f"Unsupported TTS language: {lang}")
114
 
115
  print(f"[*] Loading TTS model for {lang} ({model_id}) on {self.device}...")
116
+ try:
117
+ self.tts_tokenizers[lang] = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True)
118
+ self.tts_models[lang] = VitsModel.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True).to(self.device)
119
+ print(f"[✓] TTS model for {lang} loaded successfully.")
120
+ except Exception as e:
121
+ print(f"[!] Error loading TTS for {lang}: {e}")
122
+ self.tts_tokenizers[lang] = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir)
123
+ self.tts_models[lang] = VitsModel.from_pretrained(model_id, cache_dir=self.cache_dir).to(self.device)
124
 
125
  return self.tts_models[lang], self.tts_tokenizers[lang]
126
 
 
160
  stride_length_s=5,
161
  generate_kwargs=gen_kwargs
162
  )
163
+ self._clear_memory()
164
 
165
  # Extract segments from chunks
166
  chunks = result.get("chunks", [])
 
232
  )
233
 
234
  translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
235
+ self._clear_memory()
236
  return translated_text
237
 
238
  def translate_batch(self, texts, src_lang, tgt_lang):
 
292
  )
293
 
294
  translated_texts = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
295
+ self._clear_memory()
296
 
297
  # Map back to full results list
298
  for i, idx in enumerate(non_empty_indices):
 
343
 
344
  # MMS-TTS models output sample rate is 16000Hz
345
  sf.write(output_path, waveform_numpy, samplerate=16000)
346
+ self._clear_memory()
347
  print(f"[✓] TTS audio written to: {output_path}")
348
  return output_path