Update handler.py
handler.py +19 -3
@@ -6,6 +6,7 @@ import torch
 import time
 
 import os
+from transformers import BitsAndBytesConfig
 #model_dir2 = os.path.abspath("json_extraction_all")
 model_dir2 = "Reyad-Ahmmed/getvars-generic"
 
@@ -23,8 +24,23 @@ class EndpointHandler:
         """
         model_name = model_dir2 #"./json_extraction_all" # Pretrained model for sentiment analysis
         self.tokenizer = T5Tokenizer.from_pretrained(model_name)
-        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
-        self.model.eval() # Set model to evaluation mode (no training)
+
+        #self.model = T5ForConditionalGeneration.from_pretrained(model_name)
+        #self.model.eval() # Set model to evaluation mode (no training)
+
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16, # Match input dtype for faster inference
+            bnb_4bit_use_double_quant=True # Optional: Improves quantization efficiency
+        )
+
+        # Load quantized model
+        self.model = T5ForConditionalGeneration.from_pretrained(
+            model_output_path,
+            quantization_config=quantization_config,
+            device_map="auto" # Automatically uses GPU if available
+        )
+
         print(f"Loaded model: {model_name}")
 
     def __call__(self, inputs):
@@ -53,7 +69,7 @@
         start_time = time.time()
 
         # Perform inference
-        with torch.no_grad():
+        with torch.inference_mode():
            output_ids = self.model.generate(input_ids, max_length=100, temperature=0.3)
 
        json_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
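The heart of the change: instead of loading the T5 checkpoint in full precision and calling eval(), the handler now loads it through bitsandbytes 4-bit quantization. Below is a minimal sketch of that loading path, assuming the bitsandbytes package and a CUDA GPU are available. Note that the diff passes model_output_path to from_pretrained, a name not defined anywhere in the shown context, so the sketch substitutes model_name, which the surrounding lines suggest was intended.

import torch
from transformers import BitsAndBytesConfig, T5ForConditionalGeneration, T5Tokenizer

model_name = "Reyad-Ahmmed/getvars-generic"

# 4-bit weight storage with fp16 compute; double quantization additionally
# compresses the per-block quantization constants.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(
    model_name,  # assumption: the diff's model_output_path means this checkpoint
    quantization_config=quantization_config,
    device_map="auto",  # dispatches layers to the GPU when one is available
)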
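The last hunk swaps torch.no_grad() for torch.inference_mode(), which disables autograd bookkeeping more aggressively (no version counters or view tracking) and is the preferred context manager for pure inference. A short sketch of the generation path after the change, wrapped in a hypothetical run_inference helper standing in for the handler's __call__:

import torch

def run_inference(model, tokenizer, text):
    # Move inputs to the device the auto-mapped model was placed on.
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    # inference_mode() skips gradient tracking entirely, so it is slightly
    # faster and lighter on memory than no_grad() during generation.
    with torch.inference_mode():
        output_ids = model.generate(input_ids, max_length=100, temperature=0.3)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

One caveat carried over from the original call: generate only applies temperature when do_sample=True is also passed, so as written the model decodes greedily and temperature=0.3 has no effect.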