Peter Mutwiri committed on
Commit
698a188
Β·
1 Parent(s): f77f60f

refactored load ml service

Browse files
Files changed (2) hide show
  1. app/main.py +7 -4
  2. app/service/llm_service.py +48 -22
app/main.py CHANGED
@@ -24,7 +24,7 @@ from app.deps import get_current_user, rate_limit_org, verify_api_key, check_all
24
  from app.tasks.analytics_worker import redis_listener, trigger_kpi_computation
25
  from app.service.vector_service import cleanup_expired_vectors
26
  from app.routers import health, datasources, reports, flags, scheduler, run, socket, analytics_stream,ai_query,schema
27
-
28
  # ─── Logger Configuration ───────────────────────────────────────────────────────
29
  logging.basicConfig(
30
  level=logging.INFO,
@@ -107,9 +107,12 @@ async def lifespan(app: FastAPI):
107
  logger.info("⏰ Starting KPI refresh scheduler...")
108
  asyncio.create_task(continuous_kpi_refresh(), name="kpi_scheduler")
109
  # Now load LLM service - it will use persistent cache
110
- from app.service.llm_service import LocalLLMService
111
- logger.info("πŸ€– LLM service initialized (will use persistent cache)")
112
-
 
 
 
113
  yield
114
 
115
  # ─── Shutdown ──────────────────────────────────────────────────────────────
 
24
  from app.tasks.analytics_worker import redis_listener, trigger_kpi_computation
25
  from app.service.vector_service import cleanup_expired_vectors
26
  from app.routers import health, datasources, reports, flags, scheduler, run, socket, analytics_stream,ai_query,schema
27
+ from app.service.llm_service import load_llm_service
28
  # ─── Logger Configuration ───────────────────────────────────────────────────────
29
  logging.basicConfig(
30
  level=logging.INFO,
 
107
  logger.info("⏰ Starting KPI refresh scheduler...")
108
  asyncio.create_task(continuous_kpi_refresh(), name="kpi_scheduler")
109
  # Now load LLM service - it will use persistent cache
110
+ try:
111
+ load_llm_service() # Starts background loading
112
+ logger.info("πŸ€– LLM service loading in background...")
113
+ except Exception as e:
114
+ logger.error(f"❌ LLM load failed: {e}")
115
+ # Continue anyway - LLM is optional for some features
116
  yield
117
 
118
  # ─── Shutdown ──────────────────────────────────────────────────────────────
app/service/llm_service.py CHANGED
@@ -6,6 +6,7 @@ import logging
6
  from threading import Thread, Lock
7
  import time
8
  import json
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
@@ -21,33 +22,36 @@ class LocalLLMService:
21
  self._lock = Lock()
22
 
23
  # βœ… Use persistent cache
24
- cache_dir = "/data/hf_cache"
25
- os.makedirs(cache_dir, exist_ok=True)
26
 
27
- logger.info("πŸš€ Starting background LLM load...")
28
- Thread(target=self._load_model_background, daemon=True).start()
29
 
30
- def _load_model_background(self):
31
- """Load model in background thread with persistent cache"""
32
  with self._lock:
33
  if self._is_loading or self._is_loaded:
 
34
  return
 
35
  self._is_loading = True
36
-
 
 
 
 
 
37
  try:
38
  logger.info(f"πŸ€– [BACKGROUND] Loading LLM: {self.model_id}...")
39
 
40
- # βœ… Use persistent cache directory
41
- cache_dir = "/data/hf_cache"
42
-
43
  # Phi-3 tokenizer
44
  self._tokenizer = AutoTokenizer.from_pretrained(
45
  self.model_id,
46
  token=HF_API_TOKEN,
47
  trust_remote_code=True,
48
- cache_dir=cache_dir # βœ… Persistent cache
49
  )
50
- # .
51
  self._tokenizer.pad_token = self._tokenizer.eos_token
52
 
53
  # Phi-3 model - OPTIMIZED for speed
@@ -58,7 +62,8 @@ class LocalLLMService:
58
  device_map="auto",
59
  low_cpu_mem_usage=True,
60
  trust_remote_code=True,
61
- attn_implementation="eager" # βœ… No flash-attn warnings
 
62
  )
63
 
64
  # βœ… FASTER pipeline settings
@@ -69,8 +74,8 @@ class LocalLLMService:
69
  device_map="auto",
70
  torch_dtype=torch.float16,
71
  trust_remote_code=True,
72
- # βœ… SPEED UP: Use pad_token_id
73
- pad_token_id=self._tokenizer.eos_token_id
74
  )
75
 
76
  with self._lock:
@@ -105,9 +110,7 @@ class LocalLLMService:
105
  raise TimeoutError("LLM loading in progress")
106
 
107
  # βœ… Phi-3 prompt format (TESTED to work)
108
- messages = [
109
- {"role": "user", "content": prompt}
110
- ]
111
 
112
  formatted_prompt = self._tokenizer.apply_chat_template(
113
  messages,
@@ -122,7 +125,7 @@ class LocalLLMService:
122
  temperature=temperature,
123
  do_sample=False,
124
  pad_token_id=self._tokenizer.eos_token_id,
125
- return_full_text=False # βœ… Only return generated text
126
  )
127
 
128
  # βœ… SAFE extraction
@@ -136,12 +139,35 @@ class LocalLLMService:
136
 
137
  # βœ… VALIDATE JSON before returning
138
  try:
139
- json.loads(response_text) # Test parse
140
  logger.info(f"[llm] Valid JSON generated: {response_text[:50]}...")
141
  return response_text
142
  except json.JSONDecodeError:
143
  logger.error(f"[llm] Invalid JSON from LLM: {response_text}")
144
  raise ValueError(f"LLM returned invalid JSON: {response_text}")
145
 
146
- # Singleton
147
- llm_service = LocalLLMService()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from threading import Thread, Lock
7
  import time
8
  import json
9
+ import os
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
22
  self._lock = Lock()
23
 
24
  # βœ… Use persistent cache
25
+ self.cache_dir = "/data/hf_cache"
26
+ os.makedirs(self.cache_dir, exist_ok=True)
27
 
28
+ # ❌ DON'T start loading here - truly lazy
29
+ self._load_thread = None
30
 
31
def load(self):
    """Explicitly start loading the model - call this ONLY after build is verified"""
    # Serialize the state transition with any concurrent caller: only one
    # caller may flip the service into the loading state.
    with self._lock:
        if self._is_loading or self._is_loaded:
            # Idempotent: a second call while loading/loaded is a no-op.
            logger.info("Model already loading or loaded")
            return
        self._is_loading = True
        # NOTE(review): reconstructed from a diff view — it is ambiguous whether
        # the three lines below sit inside or just after the `with` block;
        # behavior is equivalent either way (the lock is released promptly and
        # the spawned thread does not re-acquire it on entry) — confirm against
        # the original file.
        logger.info("πŸš€ Starting LLM load...")
        # daemon=True so an in-flight model download never blocks process exit.
        self._load_thread = Thread(target=self._load_model_background, daemon=True)
        self._load_thread.start()
42
+
43
+ def _load_model_background(self):
44
+ """Load model in background thread with persistent cache"""
45
  try:
46
  logger.info(f"πŸ€– [BACKGROUND] Loading LLM: {self.model_id}...")
47
 
 
 
 
48
  # Phi-3 tokenizer
49
  self._tokenizer = AutoTokenizer.from_pretrained(
50
  self.model_id,
51
  token=HF_API_TOKEN,
52
  trust_remote_code=True,
53
+ cache_dir=self.cache_dir
54
  )
 
55
  self._tokenizer.pad_token = self._tokenizer.eos_token
56
 
57
  # Phi-3 model - OPTIMIZED for speed
 
62
  device_map="auto",
63
  low_cpu_mem_usage=True,
64
  trust_remote_code=True,
65
+ attn_implementation="eager",
66
+ cache_dir=self.cache_dir # βœ… Persistent cache
67
  )
68
 
69
  # βœ… FASTER pipeline settings
 
74
  device_map="auto",
75
  torch_dtype=torch.float16,
76
  trust_remote_code=True,
77
+ pad_token_id=self._tokenizer.eos_token_id,
78
+ cache_dir=self.cache_dir
79
  )
80
 
81
  with self._lock:
 
110
  raise TimeoutError("LLM loading in progress")
111
 
112
  # βœ… Phi-3 prompt format (TESTED to work)
113
+ messages = [{"role": "user", "content": prompt}]
 
 
114
 
115
  formatted_prompt = self._tokenizer.apply_chat_template(
116
  messages,
 
125
  temperature=temperature,
126
  do_sample=False,
127
  pad_token_id=self._tokenizer.eos_token_id,
128
+ return_full_text=False
129
  )
130
 
131
  # βœ… SAFE extraction
 
139
 
140
  # βœ… VALIDATE JSON before returning
141
  try:
142
+ json.loads(response_text)
143
  logger.info(f"[llm] Valid JSON generated: {response_text[:50]}...")
144
  return response_text
145
  except json.JSONDecodeError:
146
  logger.error(f"[llm] Invalid JSON from LLM: {response_text}")
147
  raise ValueError(f"LLM returned invalid JSON: {response_text}")
148
 
149
+
150
# βœ… LAZY singleton creation - instance created ONLY when first requested
_llm_service_instance = None
_llm_service_create_lock = Lock()  # guards one-time instance creation across threads


def get_llm_service():
    """Get or create the singleton LLM service (lazy initialization).

    Returns:
        LocalLLMService: the process-wide service instance. Creation happens
        on the first call only and does NOT start model loading — call
        ``load_llm_service()`` (or ``.load()``) to trigger that.
    """
    global _llm_service_instance

    # Double-checked locking: the unlocked fast path keeps repeated lookups
    # cheap, while the locked re-check prevents two threads from racing past
    # the `is None` test and creating (and leaking) separate instances.
    if _llm_service_instance is None:
        with _llm_service_create_lock:
            if _llm_service_instance is None:
                logger.info("πŸ†• Creating LLM service instance (lazy)")
                _llm_service_instance = LocalLLMService()

    return _llm_service_instance
162
+
163
+
164
def load_llm_service():
    """
    Explicitly load the LLM service.
    Call this AFTER startup sequence to ensure build is successful.
    """
    # Fetch (or lazily create) the singleton, then kick off background
    # loading unless a load has already started or finished.
    svc = get_llm_service()
    already_active = svc.is_loaded or svc.is_loading
    if not already_active:
        svc.load()
        logger.info("πŸ€– LLM service loading triggered")
    return svc