nexusbert committed on
Commit
3b4c0b9
·
1 Parent(s): 9faaac1

push atlas first

Browse files
Files changed (2)
  1. Dockerfile +4 -0
  2. app.py +7 -4
Dockerfile CHANGED
@@ -42,6 +42,10 @@ RUN python -c "from transformers import pipeline; pipeline('text-to-speech', mod
42
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-eng')" \
43
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-yor')"
44
 
 
 
 
 
45
  # Copy project files
46
  COPY . .
47
 
 
42
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-eng')" \
43
  && python -c "from transformers import pipeline; pipeline('text-to-speech', model='facebook/mms-tts-yor')"
44
 
45
+ # Pre-load N-ATLaS model during build
46
+ RUN python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='NCAIR1/N-ATLaS')" \
47
+ && python -c "from transformers import AutoTokenizer, AutoModelForCausalLM; import torch; tokenizer = AutoTokenizer.from_pretrained('NCAIR1/N-ATLaS'); model = AutoModelForCausalLM.from_pretrained('NCAIR1/N-ATLaS', torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map='auto' if torch.cuda.is_available() else None, trust_remote_code=True, low_cpu_mem_usage=True, use_cache=True); print('N-ATLaS model loaded successfully')"
48
+
49
  # Copy project files
50
  COPY . .
51
 
app.py CHANGED
@@ -59,6 +59,10 @@ def load_models():
59
  logger.warning("Please set HF_TOKEN environment variable to access restricted models.")
60
  else:
61
  logger.info("HF_TOKEN is set and ready for authenticated model access.")
 
 
 
 
62
  logger.info("Loading TTS models...")
63
  try:
64
  tts_ha = pipeline("text-to-speech", model="facebook/mms-tts-hau", device=device)
@@ -82,8 +86,6 @@ def load_models():
82
  tts_ig = None
83
  logger.info("Igbo TTS model disabled - will return text responses for Igbo language")
84
 
85
- logger.info("N-ATLaS language identification model will be lazy-loaded on first use")
86
-
87
  logger.info("Deferred ASR model loads: will lazy-load per language on first use")
88
 
89
  def _get_asr(lang_code: str):
@@ -199,6 +201,7 @@ IGBO_WORDS = [
199
  def _load_natlas():
200
  global natlas_tokenizer, natlas_model
201
  if natlas_tokenizer is not None and natlas_model is not None:
 
202
  return True
203
 
204
  hf_token = os.getenv("HF_TOKEN")
@@ -210,8 +213,8 @@ def _load_natlas():
210
  return False
211
 
212
  try:
213
- logger.info("Lazy-loading N-ATLaS language identification model...")
214
- logger.info("This may take a few minutes as the model loads its shards...")
215
 
216
  natlas_tokenizer = AutoTokenizer.from_pretrained("NCAIR1/N-ATLaS", token=hf_token)
217
  natlas_model = AutoModelForCausalLM.from_pretrained(
 
59
  logger.warning("Please set HF_TOKEN environment variable to access restricted models.")
60
  else:
61
  logger.info("HF_TOKEN is set and ready for authenticated model access.")
62
+
63
+ logger.info("Loading N-ATLaS language identification model...")
64
+ _load_natlas()
65
+
66
  logger.info("Loading TTS models...")
67
  try:
68
  tts_ha = pipeline("text-to-speech", model="facebook/mms-tts-hau", device=device)
 
86
  tts_ig = None
87
  logger.info("Igbo TTS model disabled - will return text responses for Igbo language")
88
 
 
 
89
  logger.info("Deferred ASR model loads: will lazy-load per language on first use")
90
 
91
  def _get_asr(lang_code: str):
 
201
  def _load_natlas():
202
  global natlas_tokenizer, natlas_model
203
  if natlas_tokenizer is not None and natlas_model is not None:
204
+ logger.info("N-ATLaS model already loaded")
205
  return True
206
 
207
  hf_token = os.getenv("HF_TOKEN")
 
213
  return False
214
 
215
  try:
216
+ logger.info("Loading N-ATLaS language identification model...")
217
+ logger.info("Model files are pre-cached from Docker build, loading should be faster...")
218
 
219
  natlas_tokenizer = AutoTokenizer.from_pretrained("NCAIR1/N-ATLaS", token=hf_token)
220
  natlas_model = AutoModelForCausalLM.from_pretrained(