onisj committed on
Commit
b61bfe4
·
verified ·
1 Parent(s): 8fe6107

handler created

Browse files
Files changed (1) hide show
  1. handler.py +79 -0
handler.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
3
+ from sentence_transformers import SentenceTransformer
4
+ import torch
5
+ import os
6
+
7
class EndpointHandler:
    """Inference Endpoints handler that routes a model directory to the right
    task pipeline.

    The task is inferred from the model's ``config.json`` (``model_type``) or
    from a set of well-known directory names, then the matching backend is
    loaded: a ``transformers`` text-generation or text-classification
    pipeline, or a ``SentenceTransformer`` for sentence embeddings.
    """

    # Default pipeline kwargs per task. User-supplied "parameters" are merged
    # OVER these so a caller-provided key (e.g. max_length) overrides the
    # default instead of raising "got multiple values for keyword argument".
    _GENERATION_DEFAULTS = {"max_length": 50, "num_return_sequences": 1}
    _CLASSIFICATION_DEFAULTS = {"return_all_scores": True}

    def __init__(self, path: str = ""):
        """Load the model at *path* for the task inferred by _determine_task.

        Raises:
            ValueError: if the task cannot be determined or is unsupported.
        """
        self.path = path
        self.task = self._determine_task()
        if self.task == "text-generation":
            self.model, self.tokenizer, self.pipeline = self._build_pipeline(
                "text-generation", AutoModelForCausalLM
            )
        elif self.task == "text-classification":
            self.model, self.tokenizer, self.pipeline = self._build_pipeline(
                "text-classification", AutoModelForSequenceClassification
            )
        elif self.task == "sentence-embedding":
            self.model = SentenceTransformer(path)
        else:
            raise ValueError(f"Unsupported task: {self.task} for model at {path}")

    def _build_pipeline(self, task: str, model_cls):
        """Load model + tokenizer from self.path and wrap them in a pipeline.

        Shared by the text-generation and text-classification branches of
        __init__ (previously duplicated inline). Uses fp16 on GPU, fp32 on CPU.

        Returns:
            (model, tokenizer, pipeline) tuple.
        """
        model = model_cls.from_pretrained(
            self.path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.path)
        task_pipeline = pipeline(
            task,
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        return model, tokenizer, task_pipeline

    def _determine_task(self) -> str:
        """Infer which task the model at self.path serves.

        Decision order: known generation model types / directory names, then
        classification, then sentence-embedding (by directory name or the
        presence of ``sentence_bert_config.json``).

        Returns:
            One of "text-generation", "text-classification",
            "sentence-embedding".

        Raises:
            ValueError: if config.json is missing or no rule matches.
        """
        config_path = os.path.join(self.path, "config.json")
        if not os.path.exists(config_path):
            raise ValueError(f"config.json not found in {self.path}")

        config = AutoConfig.from_pretrained(self.path)
        model_type = getattr(config, "model_type", None)

        text_generation_types = ["gpt2"]
        text_classification_types = ["bert", "distilbert", "roberta"]

        # Fall back to the project's well-known model directory names when
        # the config's model_type is ambiguous or absent.
        model_name = self.path.split("/")[-1].lower()
        if model_type in text_generation_types or model_name in [
            "fine_tuned_gpt2", "merged_distilgpt2", "emotion_model",
        ]:
            return "text-generation"
        elif model_type in text_classification_types or model_name in [
            "emotion_classifier", "intent_classifier", "intent_fallback",
        ]:
            return "text-classification"
        elif model_name in [
            "intent_encoder", "sentence_transformer",
        ] or "sentence_bert_config.json" in os.listdir(self.path):
            return "sentence-embedding"
        raise ValueError(f"Could not determine task for model_type: {model_type}, model_name: {model_name}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on an Inference Endpoints request payload.

        Args:
            data: request dict with "inputs" (text or list of texts) and an
                optional "parameters" dict forwarded to the pipeline.

        Returns:
            A list of result dicts, or a single-element list with an "error"
            key on bad input or inference failure (never raises to the caller).
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", None)
        if not inputs:
            return [{"error": "No inputs provided"}]

        try:
            if self.task == "text-generation":
                # Caller parameters override defaults; previously both were
                # passed as keywords and a duplicate key raised TypeError.
                kwargs = {**self._GENERATION_DEFAULTS, **(parameters or {})}
                result = self.pipeline(inputs, **kwargs)
                return [{"generated_text": item["generated_text"]} for item in result]
            elif self.task == "text-classification":
                kwargs = {**self._CLASSIFICATION_DEFAULTS, **(parameters or {})}
                result = self.pipeline(inputs, **kwargs)
                # return_all_scores yields a list of per-input score lists;
                # flatten to one list of {label, score} dicts.
                return [
                    {"label": item["label"], "score": item["score"]}
                    for sublist in result
                    for item in sublist
                ]
            elif self.task == "sentence-embedding":
                embeddings = self.model.encode(inputs)
                return [{"embeddings": embeddings.tolist()}]
            return [{"error": f"Unsupported task: {self.task}"}]
        except Exception as e:
            # Surface failures as a structured error payload instead of a 500.
            return [{"error": f"Inference failed: {str(e)}"}]