Sandei committed on
Commit ad0633b · 1 Parent(s): fbcd161

<s> response issue

Files changed (1)
  1. service/llm_service.py +18 -9
service/llm_service.py CHANGED
@@ -9,11 +9,13 @@ class LLMService:
         # torch.set_num_threads(...)
         # torch.set_num_interop_threads(...)
 
+        # Tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.model_name,
             use_fast=True
         )
 
+        # Model in FP32 on CPU
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.float32
@@ -23,30 +25,37 @@ class LLMService:
 
         print("LLM loaded | dtype:", next(self.model.parameters()).dtype)
 
-    def generate(self, prompt: str) -> str:
+    def generate(self, user_query: str) -> str:
+        # Wrap user input with role tokens for TinyLlama
+        prompt = f"<|user|>{user_query}<|assistant|>"
+
         inputs = self.tokenizer(
             prompt,
             return_tensors="pt",
             truncation=True,
-            max_length=640  # ⬅️ important
+            max_length=640  # maintain context window
         )
 
         with torch.no_grad():
             output = self.model.generate(
                 **inputs,
-                max_new_tokens=96,  # ⬅️ enough for helpdesk
-                do_sample=False,
-                eos_token_id=self.tokenizer.eos_token_id
+                max_new_tokens=120,  # slightly higher for complete answer
+                do_sample=False,  # deterministic + faster
+                eos_token_id=self.tokenizer.eos_token_id,
+                use_cache=True
             )
 
-        text = self.tokenizer.decode(
-            output[0],
-            skip_special_tokens=False
-        )
+        # Decode and remove special tokens
+        text = self.tokenizer.decode(output[0], skip_special_tokens=True)
 
         return self._clean(text)
 
     def _clean(self, text: str) -> str:
+        """
+        Maintains your previous cleaning logic:
+        - Extract after <|assistant|>
+        - Stop at <|system|> or <|user|>
+        """
         if "<|assistant|>" in text:
             text = text.split("<|assistant|>")[-1]
 
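
For context on the commit title: decoding with skip_special_tokens=True is what keeps the literal <s> (and </s>) tokens out of responses, while the <|user|>/<|assistant|>/<|system|> role markers can still show up in the decoded text, which is presumably why _clean() is kept. Below is a minimal, self-contained sketch of that string-level path. The build_prompt() and clean() helpers and the sample decoded string are illustrative assumptions written for this note, not code taken from the repository; they only mirror the behaviour the diff and its docstring describe.

def build_prompt(user_query: str) -> str:
    # Same role-token wrapping the new generate() applies before tokenizing.
    return f"<|user|>{user_query}<|assistant|>"

def clean(text: str) -> str:
    # Keep only what follows the last assistant marker ...
    if "<|assistant|>" in text:
        text = text.split("<|assistant|>")[-1]
    # ... and cut the reply off at the next role marker in case the model
    # starts another turn (the stop behaviour the docstring mentions).
    for marker in ("<|system|>", "<|user|>"):
        text = text.split(marker)[0]
    return text.strip()

# A decoded output as it might look for a helpdesk question once
# skip_special_tokens=True has already removed <s>/</s>:
decoded = "<|user|>How do I reset my VPN password?<|assistant|>Open the portal, go to Account, then choose Reset password.<|user|>"
print(clean(decoded))  # -> Open the portal, go to Account, then choose Reset password.

Run as a plain Python script this prints only the assistant's answer; whether the repository's actual _clean() matches this sketch exactly is not shown in the diff.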