Token Classification
GLiNER
PyTorch
multilingual
bert
File size: 1,257 Bytes
4e97a81
1903c5a
4e97a81
1903c5a
39b1f14
1903c5a
 
4e97a81
 
 
 
 
 
1903c5a
4e97a81
 
1903c5a
4e97a81
1903c5a
4e97a81
b53706a
7ccef2b
1903c5a
 
 
 
 
 
 
db7185a
1903c5a
7ccef2b
1903c5a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Dict, List, Any
from gliner import GLiNER

class EndpointHandler:
    """Callable request handler that extracts entities from text with GLiNER.

    An instance loads the model once in ``__init__`` and is then invoked
    once per request with the request payload as a dictionary.
    """

    # Hub identifier of the GLiNER checkpoint loaded at start-up.
    MODEL_ID = "urchade/gliner_multi-v2.1"

    # Labels used when a request does not supply its own "labels".
    DEFAULT_LABELS = ("party", "document title")

    def __init__(self, path=""):
        """Load the GLiNER model.

        Args:
            path: Currently unused; the model is always loaded from
                ``MODEL_ID``.
                NOTE(review): presumably the serving runtime passes the
                repository directory here -- confirm before wiring it in.
        """
        self.model = GLiNER.from_pretrained(self.MODEL_ID)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Extract entities from the request text.

        Args:
            data: The request payload, including:
                - "inputs": The text from which to extract entities.
                - "labels" (optional): The entity labels to predict, as a
                  list of strings or a single string. When missing or
                  empty, ``DEFAULT_LABELS`` is used.

        Returns:
            One single-key dictionary per detected entity, mapping the
            entity's label to the matched text, in the order the model
            returned them. An empty list when "inputs" is missing or empty.
        """
        inputs = data.get("inputs", "")

        # Honour caller-supplied labels (as the documented payload promises)
        # and fall back to the defaults so label-less requests are unchanged.
        labels = data.get("labels") or list(self.DEFAULT_LABELS)
        if isinstance(labels, str):
            # A bare string would otherwise be treated as a sequence of
            # single-character labels.
            labels = [labels]
        print('labels',labels)

        # Nothing to extract from; skip the model call entirely.
        if not inputs:
            return []

        entities = self.model.predict_entities(inputs, labels)

        # Reduce each prediction to {label: matched text}.
        formatted_results = []
        for entity in entities:
            formatted_entity = {entity["label"]: entity["text"]}
            print(formatted_entity)
            formatted_results.append(formatted_entity)

        return formatted_results