Reyad-Ahmmed commited on
Commit
9f20041
·
verified ·
1 Parent(s): 5add132

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -3
handler.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import time
7
 
8
  import os
 
9
  #model_dir2 = os.path.abspath("json_extraction_all")
10
  model_dir2 = "Reyad-Ahmmed/getvars-generic"
11
 
@@ -23,8 +24,23 @@ class EndpointHandler:
23
  """
24
  model_name = model_dir2 #"./json_extraction_all" # Pretrained model for sentiment analysis
25
  self.tokenizer = T5Tokenizer.from_pretrained(model_name)
26
- self.model = T5ForConditionalGeneration.from_pretrained(model_name)
27
- self.model.eval() # Set model to evaluation mode (no training)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  print(f"Loaded model: {model_name}")
29
 
30
  def __call__(self, inputs):
@@ -53,7 +69,7 @@ class EndpointHandler:
53
  start_time = time.time()
54
 
55
  # Perform inference
56
- with torch.no_grad():
57
  output_ids = self.model.generate(input_ids, max_length=100, temperature=0.3)
58
 
59
  json_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
6
  import time
7
 
8
  import os
9
+ from transformers import BitsAndBytesConfig
10
  #model_dir2 = os.path.abspath("json_extraction_all")
11
  model_dir2 = "Reyad-Ahmmed/getvars-generic"
12
 
 
24
  """
25
  model_name = model_dir2 #"./json_extraction_all" # Pretrained model for sentiment analysis
26
  self.tokenizer = T5Tokenizer.from_pretrained(model_name)
27
+
28
+ #self.model = T5ForConditionalGeneration.from_pretrained(model_name)
29
+ #self.model.eval() # Set model to evaluation mode (no training)
30
+
31
+ quantization_config = BitsAndBytesConfig(
32
+ load_in_4bit=True,
33
+ bnb_4bit_compute_dtype=torch.float16, # Match input dtype for faster inference
34
+ bnb_4bit_use_double_quant=True # Optional: Improves quantization efficiency
35
+ )
36
+
37
+ # Load quantized model
38
+ self.model = T5ForConditionalGeneration.from_pretrained(
39
+ model_name,  # FIXME(review): committed code says `model_output_path`, which is undefined here (NameError at load); `model_name` (set above from model_dir2) is the intended checkpoint
40
+ quantization_config=quantization_config,
41
+ device_map="auto" # Automatically uses GPU if available
42
+ )
43
+
44
  print(f"Loaded model: {model_name}")
45
 
46
  def __call__(self, inputs):
 
69
  start_time = time.time()
70
 
71
  # Perform inference
72
+ with torch.inference_mode():
73
output_ids = self.model.generate(input_ids, max_length=100, temperature=0.3)  # NOTE(review): temperature is ignored unless do_sample=True — confirm whether sampling was intended
74
 
75
  json_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)