therandomuser03 committed on
Commit
8d1fac5
·
1 Parent(s): 561a3db

update backend - HF

Browse files
Dockerfile CHANGED
@@ -40,6 +40,7 @@ RUN python download_models.py
40
  # 9. Environment & Port settings (7860 is HF Spaces standard)
41
  ENV PYTHONPATH=/app
42
  ENV USE_EMBEDDED_LLM=True
 
43
  EXPOSE 7860
44
 
45
  # 10. Run the app with Uvicorn
 
40
  # 9. Environment & Port settings (7860 is HF Spaces standard)
41
  ENV PYTHONPATH=/app
42
  ENV USE_EMBEDDED_LLM=True
43
+ ENV HF_HUB_OFFLINE=1
44
  EXPOSE 7860
45
 
46
  # 10. Run the app with Uvicorn
app/services/crisis_engine.py CHANGED
@@ -63,10 +63,14 @@ def initialize_crisis_classifier() -> None:
63
  global _zero_shot_pipeline, _load_error
64
  try:
65
  from transformers import pipeline as hf_pipeline
66
- logger.info("Loading crisis zero-shot classifier...")
 
 
 
 
67
  _zero_shot_pipeline = hf_pipeline(
68
  "zero-shot-classification",
69
- model="cross-encoder/nli-MiniLM2-L6-H768",
70
  device=-1, # CPU
71
  )
72
  logger.info("✅ Crisis classifier loaded.")
 
63
  global _zero_shot_pipeline, _load_error
64
  try:
65
  from transformers import pipeline as hf_pipeline
66
+ import os
67
+
68
+ local_path = os.path.join("app", "ml_assets", "crisis_model")
69
+ logger.info("Loading crisis zero-shot classifier from %s", local_path)
70
+
71
  _zero_shot_pipeline = hf_pipeline(
72
  "zero-shot-classification",
73
+ model=local_path if os.path.exists(local_path) else "cross-encoder/nli-MiniLM2-L6-H768",
74
  device=-1, # CPU
75
  )
76
  logger.info("✅ Crisis classifier loaded.")
app/services/text_emotion_engine.py CHANGED
@@ -23,10 +23,15 @@ def _load_pipeline(model_name: str) -> None:
23
  global _pipeline, _load_error
24
  try:
25
  from transformers import pipeline as hf_pipeline
26
- logger.info("Loading DistilBERT text emotion model: %s", model_name)
 
 
 
 
 
27
  _pipeline = hf_pipeline(
28
  "text-classification",
29
- model=model_name,
30
  top_k=None, # Return ALL labels
31
  truncation=True,
32
  max_length=512,
 
23
  global _pipeline, _load_error
24
  try:
25
  from transformers import pipeline as hf_pipeline
26
+ import os
27
+
28
+ # Determine local path
29
+ local_path = os.path.join("app", "ml_assets", "distilbert_model")
30
+
31
+ logger.info("Loading DistilBERT text emotion model from %s", local_path)
32
  _pipeline = hf_pipeline(
33
  "text-classification",
34
+ model=local_path if os.path.exists(local_path) else model_name,
35
  top_k=None, # Return ALL labels
36
  truncation=True,
37
  max_length=512,
download_models.py CHANGED
@@ -16,6 +16,13 @@ FACE_MODEL_PATH = os.path.join(ML_ASSETS, "emotion_model_trained.h5")
16
  MEDS_CSV_PATH = os.path.join(ML_ASSETS, "MEDICATION.csv")
17
  LLAMA_GGUF_PATH = os.path.join(ML_ASSETS, "llama-3-8b-instruct.Q4_K_M.gguf")
18
 
 
 
 
 
 
 
 
19
  def download_drive_file(file_id, output_path):
20
  if not os.path.exists(output_path):
21
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
@@ -42,6 +49,19 @@ def download_hf_model(repo_id, filename, output_path):
42
  else:
43
  print(f"✅ Found {output_path}, skipping.")
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  if __name__ == "__main__":
46
  print("🚀 Starting Production Model Sync...")
47
 
@@ -53,6 +73,13 @@ if __name__ == "__main__":
53
  try:
54
  download_hf_model(LLAMA_REPO, LLAMA_FILE, LLAMA_GGUF_PATH)
55
  except Exception as e:
56
- print(f"⚠️ HF Download failed (expected on local dev if no internet): {e}")
 
 
 
 
 
 
 
57
 
58
  print("✅ All models synchronized!")
 
16
  MEDS_CSV_PATH = os.path.join(ML_ASSETS, "MEDICATION.csv")
17
  LLAMA_GGUF_PATH = os.path.join(ML_ASSETS, "llama-3-8b-instruct.Q4_K_M.gguf")
18
 
19
+ # HF Transformers (Downloaded via snapshot_download for full directory)
20
+ CRISIS_MODEL_REPO = "cross-encoder/nli-MiniLM2-L6-H768"
21
+ DISTILBERT_MODEL_REPO = "bhadresh-savani/distilbert-base-uncased-emotion"
22
+
23
+ CRISIS_MODEL_PATH = os.path.join(ML_ASSETS, "crisis_model")
24
+ DISTILBERT_MODEL_PATH = os.path.join(ML_ASSETS, "distilbert_model")
25
+
26
  def download_drive_file(file_id, output_path):
27
  if not os.path.exists(output_path):
28
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
 
49
  else:
50
  print(f"✅ Found {output_path}, skipping.")
51
 
52
+ def download_hf_directory(repo_id, output_dir):
53
+ from huggingface_hub import snapshot_download
54
+ if not os.path.exists(output_dir) or not os.listdir(output_dir):
55
+ print(f"⬇️ Downloading HF repo: {repo_id} to {output_dir}...")
56
+ snapshot_download(
57
+ repo_id=repo_id,
58
+ local_dir=output_dir,
59
+ local_dir_use_symlinks=False,
60
+ ignore_patterns=["*.msgpack", "*.h5", "*.ot", "rust_model.ot"] # save space, only PyTorch/Safetensors needed
61
+ )
62
+ else:
63
+ print(f"✅ Found {output_dir}, skipping.")
64
+
65
  if __name__ == "__main__":
66
  print("🚀 Starting Production Model Sync...")
67
 
 
73
  try:
74
  download_hf_model(LLAMA_REPO, LLAMA_FILE, LLAMA_GGUF_PATH)
75
  except Exception as e:
76
+ print(f"⚠️ HF LLaMA Download failed (expected on local dev if no internet): {e}")
77
+
78
+ # 3. HF Transformers Pipeline Models
79
+ try:
80
+ download_hf_directory(CRISIS_MODEL_REPO, CRISIS_MODEL_PATH)
81
+ download_hf_directory(DISTILBERT_MODEL_REPO, DISTILBERT_MODEL_PATH)
82
+ except Exception as e:
83
+ print(f"⚠️ HF Transformers Download failed: {e}")
84
 
85
  print("✅ All models synchronized!")