|
|
import json |
|
|
from datetime import datetime |
|
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
|
from transformers import TrainingArguments, Trainer |
|
|
import torch |
|
|
import time |
|
|
|
|
|
import os |
|
|
from transformers import BitsAndBytesConfig |
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned T5 checkpoint this handler serves by
# default (also used as the fallback when no model_dir is supplied).
model_dir2 = "Reyad-Ahmmed/getvars-generic"
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler for a fine-tuned T5 seq2seq model.

    Loads the model and tokenizer once at construction time, then answers
    requests through ``__call__`` in the shape expected by Hugging Face
    inference endpoints: text in, (ideally JSON) generation out.
    """

    def __init__(self, model_dir):
        """Eagerly load the model so the first request pays no warm-up cost.

        Args:
            model_dir: Path or hub id of the model to load. Falls back to the
                module-level default ``model_dir2`` when falsy.
                (BUG FIX: the original ignored this argument entirely and
                always loaded ``model_dir2``.)
        """
        self.model_dir = model_dir or model_dir2
        self.model = None
        self.tokenizer = None
        self.load()

    def load(self):
        """Load the T5 tokenizer and model, placing weights automatically.

        fp16 halves the memory footprint; ``device_map="auto"`` puts the
        weights on a GPU when one is visible, otherwise on CPU.
        """
        model_name = self.model_dir
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

        self.model = T5ForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()

        device = next(self.model.parameters()).device
        print(f"Model is loaded on: {device}")
        print(f"Loaded model: {model_name}")

    def __call__(self, inputs):
        """Generate a response for one request.

        Args:
            inputs: Either ``{"inputs": "your text"}`` or ``["your text"]``
                (only the first element of a list is used).

        Returns:
            The parsed JSON object emitted by the model, the raw decoded
            string when the generation is not valid JSON, or an
            ``{"error": ...}`` dict on bad input or any internal failure.
        """
        try:
            if self.tokenizer is None or self.model is None:
                raise ValueError("Model and tokenizer were not loaded properly.")

            if isinstance(inputs, list) and len(inputs) > 0:
                user_text = inputs[0]
            elif isinstance(inputs, dict) and "inputs" in inputs:
                user_text = inputs["inputs"]
            else:
                return {"error": "Invalid input format. Expected {'inputs': 'your text'} or ['your text']."}

            # BUG FIX: tokens were hard-coded onto "cuda", which crashes on
            # CPU-only hosts even though the model loads with device_map="auto".
            # Move them to wherever the model's weights actually live.
            model_device = next(self.model.parameters()).device
            input_ids = self.tokenizer(user_text, return_tensors="pt").input_ids.to(model_device)

            start_time = time.time()

            with torch.inference_mode():
                # Greedy decoding. The old temperature=0.3 was inert without
                # do_sample=True (transformers warns and ignores it), so it is
                # dropped -- output is unchanged.
                output_ids = self.model.generate(input_ids, max_length=100)

            json_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

            inference_time = time.time() - start_time
            print(f"Inference Time: {inference_time:.4f} seconds")

            # Best effort: the model is expected to emit JSON, but fall back
            # to the raw decoded string when it does not parse.
            try:
                return json.loads(json_output)
            except json.JSONDecodeError:
                return json_output

        except Exception as e:
            # Endpoint boundary: report the failure rather than crash the
            # serving process.
            return {"error": f"Unexpected error: {str(e)}"}
|
|
|