# getvars-generic / handler.py
# NOTE(review): the following two lines were HuggingFace Hub page residue
# ("Reyad-Ahmmed's picture / Update handler.py / cacb98b verified") captured
# by a web scrape; kept here as a comment so the module parses.
import json
from datetime import datetime
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import TrainingArguments, Trainer
import torch
import time
import os
from transformers import BitsAndBytesConfig
#model_dir2 = os.path.abspath("json_extraction_all")
model_dir2 = "Reyad-Ahmmed/getvars-generic"
class EndpointHandler:
    """HuggingFace Inference Endpoints handler wrapping a fine-tuned T5 model.

    Loads the tokenizer and model once at construction, then turns incoming
    text into a generated string that is expected (but not guaranteed) to be
    JSON; well-formed JSON is parsed before being returned.
    """

    def __init__(self, model_dir):
        """Create the handler and eagerly load the model.

        Args:
            model_dir: Directory/repo id supplied by the endpoint runtime.
                NOTE(review): deliberately ignored in favor of the module-level
                ``model_dir2`` hub repo — preserved from the original code so
                existing deployments keep resolving the same weights.
        """
        self.model_dir = model_dir2
        self.model = None
        self.tokenizer = None
        self.load()  # Ensure model loads on initialization

    def load(self):
        """Load the T5 tokenizer and model (float16 on GPU, float32 on CPU)."""
        model_name = self.model_dir  # was the global model_dir2; same value, now consistent
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        # float16 halves memory and speeds up GPU inference, but half-precision
        # T5 on CPU is poorly supported — fall back to float32 there.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto",  # Automatically uses GPU if available
        )
        self.model.eval()  # Set model to evaluation mode (no training)
        # Report where the weights actually landed ('cuda:0' when a GPU exists).
        device = next(self.model.parameters()).device
        print(f"Model is loaded on: {device}")
        print(f"Loaded model: {model_name}")

    def __call__(self, inputs):
        """Run generation on the user's text.

        Args:
            inputs: Either ``{'inputs': 'your text'}`` or ``['your text']``.

        Returns:
            The parsed JSON object when the model emits valid JSON, the raw
            generated string otherwise, or an ``{'error': ...}`` dict on any
            failure (the handler never raises to the caller).
        """
        try:
            if self.tokenizer is None or self.model is None:
                raise ValueError("Model and tokenizer were not loaded properly.")
            # Handle different input formats
            if isinstance(inputs, list) and len(inputs) > 0:
                user_text = inputs[0]
            elif isinstance(inputs, dict) and "inputs" in inputs:
                user_text = inputs["inputs"]
            else:
                return {"error": "Invalid input format. Expected {'inputs': 'your text'} or ['your text']."}
            # Send the input to wherever the model actually lives; the original
            # hard-coded .to("cuda"), which crashed on CPU-only hosts despite
            # device_map="auto" above.
            device = next(self.model.parameters()).device
            input_ids = self.tokenizer(user_text, return_tensors="pt").input_ids.to(device)
            # Measure inference time
            start_time = time.time()
            with torch.inference_mode():
                # Greedy decoding. The original passed temperature=0.3 without
                # do_sample=True, which transformers ignores (warning only), so
                # dropping it keeps behavior identical and deterministic.
                output_ids = self.model.generate(input_ids, max_length=100)
            json_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            inference_time = time.time() - start_time
            print(f"Inference Time: {inference_time:.4f} seconds")
            # Best effort: return structured JSON when possible, else raw text.
            try:
                return json.loads(json_output)
            except json.JSONDecodeError:  # narrowed from bare except
                return json_output
        except Exception as e:
            # Endpoint contract: surface failures as an error payload, not a 500.
            return {"error": f"Unexpected error: {str(e)}"}