Update handler.py
Browse files — handler.py (+23 −28)
handler.py
CHANGED
|
@@ -1,42 +1,37 @@
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
-
from
|
| 3 |
import os
|
| 4 |
|
| 5 |
-
|
| 6 |
-
class EndpointHandler:
    """Legacy Inference Endpoint handler: token-classification via transformers.

    Loads a model from a custom ``gliner_config.json`` and builds an NER
    ("token-classification") pipeline around it.

    NOTE(review): this version references ``AutoConfig``,
    ``AutoModelForTokenClassification``, ``BertTokenizerFast`` and
    ``pipeline``, but the corresponding ``from transformers import ...`` line
    is truncated in the file — confirm/restore that import before use.
    """

    def __init__(self, path=""):
        """Build the model, tokenizer and NER pipeline from *path*.

        Args:
            path (str): Directory holding the model files, including the
                custom ``gliner_config.json``.

        Raises:
            FileNotFoundError: If ``gliner_config.json`` is absent in *path*.
        """
        config_path = os.path.join(path, "gliner_config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Custom configuration file not found at {config_path}")

        # Load the custom configuration
        config = AutoConfig.from_pretrained(config_path)

        # Load the model using the custom configuration.
        # BUG FIX: the original referenced an undefined name `dir_model`;
        # the model directory is the `path` argument.
        self.model = AutoModelForTokenClassification.from_pretrained(path, config=config)

        # Use a pre-trained tokenizer compatible with the model.
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        # "token-classification" is the pipeline task for NER.
        self.pipeline = pipeline("token-classification", model=path, tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (Dict[str, Any]): The input data including:
                - "inputs": The text input from which to extract information.

        Returns:
            List[Dict[str, Any]]: The extracted entities.
        """
        # Get inputs.
        # NOTE(review): the body is truncated in the diff after this line;
        # as visible, the method implicitly returns None.
        inputs = data.get("inputs", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
+
from gliner import GLiNER
|
| 3 |
import os
|
| 4 |
|
| 5 |
+
class EndpointHandler:
    """Hugging Face Inference Endpoint handler wrapping a GLiNER NER model."""

    def __init__(self, path=""):
        # `path` is the model directory supplied by the Inference Endpoint
        # runtime; it is currently unused because the checkpoint is pinned
        # to a fixed hub id. TODO(review): consider loading from `path` so
        # the deployed artifact controls the weights.
        # Initialize the GLiNER model
        self.model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run named-entity extraction over the request payload.

        Args:
            data (Dict[str, Any]): The input data including:
                - "inputs": The text input from which to extract information
                  (defaults to "" when absent).
                - "labels": The labels to predict entities for
                  (defaults to [] when absent).

        Returns:
            List[Dict[str, Any]]: One dict per detected entity with keys
            "word" (surface text), "entity_group" (predicted label) and
            "score" (confidence; 1.0 when the model omits it).
        """
        # Get inputs and labels
        inputs = data.get("inputs", "")
        labels = data.get("labels", [])

        # Predict entities using GLiNER; each prediction is a dict carrying
        # at least "text" and "label" keys.
        entities = self.model.predict_entities(inputs, labels)

        # Reshape into the token-classification pipeline output schema.
        # (Comprehension replaces the original manual append loop.)
        return [
            {
                "word": entity["text"],
                "entity_group": entity["label"],
                "score": entity.get("score", 1.0),
            }
            for entity in entities
        ]
|