Sandei committed on
Commit
0a70e53
·
1 Parent(s): d5dcd77

quantization

Browse files
Files changed (1) hide show
  1. service/llm_service.py +42 -16
service/llm_service.py CHANGED
@@ -1,34 +1,60 @@
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
 
4
  class LLMService:
5
  def __init__(self):
 
 
 
6
  self.tokenizer = AutoTokenizer.from_pretrained(
7
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
8
  )
9
- self.model = AutoModelForCausalLM.from_pretrained(
10
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
11
- torch_dtype=torch.float16,
12
- device_map="auto"
 
 
 
 
 
 
 
 
13
  )
14
 
 
 
 
 
 
15
  def generate(self, prompt: str) -> str:
16
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
17
-
18
- output = self.model.generate(
19
- **inputs,
20
- max_new_tokens=256,
21
- do_sample=True,
22
- temperature=0.7,
23
- top_p=0.9,
24
- eos_token_id=self.tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
25
  )
26
 
27
- text = self.tokenizer.decode(output[0], skip_special_tokens=False)
28
  return self._clean(text)
29
 
30
  def _clean(self, text: str) -> str:
31
- # Extract content AFTER <|assistant|>
32
  if "<|assistant|>" in text:
33
  text = text.split("<|assistant|>")[-1]
34
 
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
+
5
  class LLMService:
6
def __init__(self):
    """Load TinyLlama-1.1B-Chat on CPU and apply INT8 dynamic quantization.

    Dynamic quantization replaces every ``nn.Linear`` with an INT8
    weight-quantized equivalent; all other parameters (embeddings,
    layer norms) remain float32 by design.
    """
    self.model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    # Tokenizer (fast Rust-backed implementation).
    self.tokenizer = AutoTokenizer.from_pretrained(
        self.model_name,
        use_fast=True,
    )

    # Load the model in FP32 on CPU: dynamic quantization takes a
    # float32 CPU model as its starting point.
    model = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        torch_dtype=torch.float32,
    )

    # CPU INT8 dynamic quantization: Linear weights are stored as int8
    # and dequantized on the fly at matmul time.
    # NOTE: torch.ao.quantization is the current home of this API;
    # the original call went through the deprecated torch.quantization
    # alias.
    self.model = torch.ao.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )

    self.model.eval()

    # Sanity check. The previous probe, next(parameters()).dtype, was
    # misleading: embedding weights stay float32 after dynamic
    # quantization, so it always printed float32 regardless of whether
    # quantization happened. Count quantized modules instead.
    n_quantized = sum(
        1 for m in self.model.modules()
        if "quantized" in type(m).__module__
    )
    print(f"LLM loaded; quantized modules: {n_quantized}")
33
def generate(self, prompt: str) -> str:
    """Produce a completion for *prompt* and return the cleaned text.

    The prompt is truncated to 1024 tokens. Decoding is deterministic
    (greedy, no sampling) and capped at 120 new tokens for speed; it
    stops early at the tokenizer's EOS token.
    """
    encoded = self.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )

    with torch.no_grad():
        generated = self.model.generate(
            **encoded,
            max_new_tokens=120,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
        )

    # Keep special tokens in the decoded string: _clean() relies on the
    # <|assistant|> marker to locate the answer.
    raw = self.tokenizer.decode(generated[0], skip_special_tokens=False)
    return self._clean(raw)
55
 
56
  def _clean(self, text: str) -> str:
57
+ # Extract content AFTER <|assistant|>
58
  if "<|assistant|>" in text:
59
  text = text.split("<|assistant|>")[-1]
60