File size: 3,814 Bytes
6577dc7
a91c42a
c3a86ed
 
bde1ce5
01d86e5
ed95945
2765e76
9f20041
40bce02
ee7b1d2
2765e76
 
0ad021c
 
2765e76
0abe65d
 
 
0ad021c
 
bde1ce5
 
 
208cb75
04f9167
9f20041
cacb98b
 
 
 
 
 
 
 
ed95945
 
262130c
 
 
 
ed95945
 
 
 
 
9f20041
 
ed95945
 
 
 
 
9f20041
bde1ce5
0ad021c
1c414e1
 
bde1ce5
1c414e1
 
0abe65d
 
 
bde1ce5
1c414e1
bde1ce5
1c414e1
 
 
 
a91c42a
1c414e1
 
a91c42a
bde1ce5
f1e6ae0
bde1ce5
8492ae4
 
 
bde1ce5
9f20041
4065b97
04f9167
4065b97
8492ae4
 
 
 
 
 
 
 
9d6ef4e
 
 
 
 
c68ae67
9700849
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import json
from datetime import datetime
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import TrainingArguments, Trainer
import torch
import time

import os
from transformers import BitsAndBytesConfig
#model_dir2 = os.path.abspath("json_extraction_all")
# Hugging Face Hub repo id of the pretrained T5 model/tokenizer loaded below.
# NOTE(review): EndpointHandler.__init__ ignores its `model_dir` argument and
# always uses this constant — confirm that is intentional.
model_dir2 = "Reyad-Ahmmed/getvars-generic"


class EndpointHandler:
    """Inference endpoint wrapper around a T5 sequence-to-sequence model.

    Loads the pretrained tokenizer and model named by the module-level
    ``model_dir2`` constant, then exposes a ``__call__`` interface that
    turns user text into generated output, parsed as JSON when possible.
    """

    def __init__(self, model_dir):
        # NOTE(review): the `model_dir` argument is ignored — the module-level
        # `model_dir2` constant is used instead. Confirm this is intentional.
        self.model_dir = model_dir2
        self.model = None
        self.tokenizer = None
        self.load()  # Ensure model loads on initialization

    def load(self):
        """Load the T5 tokenizer and model (float16, automatic device placement)."""
        model_name = model_dir2
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

        # float16 weights for faster inference; device_map="auto" puts the
        # model on GPU when one is available, otherwise on CPU.
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()  # inference only — disables dropout etc.

        # Report where the weights actually landed (e.g. 'cuda:0' or 'cpu').
        device = next(self.model.parameters()).device
        print(f"Model is loaded on: {device}")
        print(f"Loaded model: {model_name}")

    def __call__(self, inputs):
        """Generate model output for `inputs`.

        Accepts either a non-empty list (first element used as the prompt) or
        a dict with an "inputs" key. Returns the decoded generation parsed via
        json.loads when it is valid JSON, the raw decoded string otherwise,
        or an {"error": ...} dict on failure.
        """
        try:
            if self.tokenizer is None or self.model is None:
                raise ValueError("Model and tokenizer were not loaded properly.")

            # Handle different input formats
            if isinstance(inputs, list) and len(inputs) > 0:
                user_text = inputs[0]
            elif isinstance(inputs, dict) and "inputs" in inputs:
                user_text = inputs["inputs"]
            else:
                return {"error": "Invalid input format. Expected {'inputs': 'your text'} or ['your text']."}

            # Tokenize and move the input tensors to wherever the model lives.
            # BUG FIX: the original hard-coded .to("cuda"), which crashes when
            # device_map="auto" placed the model on CPU.
            device = next(self.model.parameters()).device
            input_ids = self.tokenizer(user_text, return_tensors="pt").input_ids.to(device)

            # Measure inference time
            start_time = time.time()

            with torch.inference_mode():
                # NOTE(review): temperature has no effect without
                # do_sample=True — generation here is effectively greedy.
                # Left as-is to preserve existing behavior.
                output_ids = self.model.generate(input_ids, max_length=100, temperature=0.3)

            json_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

            inference_time = time.time() - start_time
            print(f"Inference Time: {inference_time:.4f} seconds")

            # Return structured data when the model emitted valid JSON,
            # otherwise fall back to the raw decoded string.
            try:
                return json.loads(json_output)
            except json.JSONDecodeError:
                return json_output

        except Exception as e:
            # Top-level endpoint boundary: surface the failure to the caller
            # as a structured error instead of raising.
            return {"error": f"Unexpected error: {str(e)}"}