Peter Mutwiri committed on
Commit 4eed1ee · 1 Parent(s): 305eb68

fix: lazy load Mistral-7B for fast startup

Files changed (1)
  1. app/service/llm_service.py +26 -13
app/service/llm_service.py CHANGED
@@ -2,59 +2,72 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from app.deps import HF_API_TOKEN
+import logging
+
+logger = logging.getLogger(__name__)
 
 class LocalLLMService:
     def __init__(self):
-        # FREE, permissive license, fits in T4 GPU
         self.model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+        self._model = None
+        self._tokenizer = None
+        self._pipe = None
 
-        self.tokenizer = AutoTokenizer.from_pretrained(
+    def _load_model(self):
+        """Lazy load model on first use - cached by HF hub"""
+        if self._model is not None:
+            return  # Already loaded
+
+        logger.info(f"🤖 Loading LLM: {self.model_id}...")
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_id,
             token=HF_API_TOKEN,
             trust_remote_code=True
         )
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self._tokenizer.pad_token = self._tokenizer.eos_token
 
-        # Load to GPU automatically
-        self.model = AutoModelForCausalLM.from_pretrained(
+        self._model = AutoModelForCausalLM.from_pretrained(
             self.model_id,
             token=HF_API_TOKEN,
             torch_dtype=torch.float16,
             device_map="auto"
         )
 
-        self.pipe = pipeline(
+        self._pipe = pipeline(
             "text-generation",
-            model=self.model,
-            tokenizer=self.tokenizer,
+            model=self._model,
+            tokenizer=self._tokenizer,
             device_map="auto"
         )
+        logger.info("✅ LLM loaded successfully")
 
     def generate(self, prompt: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
-        """Generate text using local model"""
+        """Generate text (triggers model load on first call)"""
+        self._load_model()  # Lazy load
+
         messages = [
             {"role": "system", "content": "You are a data analytics assistant. Respond with valid JSON only."},
             {"role": "user", "content": prompt}
         ]
 
-        formatted_prompt = self.tokenizer.apply_chat_template(
+        formatted_prompt = self._tokenizer.apply_chat_template(
             messages,
             tokenize=False,
             add_generation_prompt=True
         )
 
-        outputs = self.pipe(
+        outputs = self._pipe(
             formatted_prompt,
             max_new_tokens=max_tokens,
             temperature=temperature,
             do_sample=True
         )
 
-        # Extract response after [/INST]
         response = outputs[0]["generated_text"]
         if "[/INST]" in response:
             return response.split("[/INST]")[-1].strip()
         return response.strip()
 
-# Singleton instance
+# Singleton instance (lightweight at import time)
 llm_service = LocalLLMService()
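
For reference, a minimal usage sketch of the lazy-loading behaviour this commit introduces, assuming the repo's app.service.llm_service module and a valid HF_API_TOKEN are available; the prompts and timing printout are illustrative only and not part of the commit:

import time

# Importing the service is now cheap: __init__ only stores the model id and
# three None placeholders, so application startup no longer blocks on the
# 7B-parameter download.
from app.service.llm_service import llm_service

start = time.perf_counter()
# The first call triggers _load_model(): tokenizer, weights, and pipeline are
# built here (and cached by the HF hub), so this call is slow exactly once.
print(llm_service.generate("Summarise the uploaded sales data as JSON."))
print(f"first call (includes model load): {time.perf_counter() - start:.1f}s")

start = time.perf_counter()
# Later calls find self._model already set, so _load_model() returns
# immediately and only inference time is paid.
print(llm_service.generate("List the three largest categories as JSON."))
print(f"subsequent call: {time.perf_counter() - start:.1f}s")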