Token Classification
GLiNER
PyTorch
multilingual
bert
Rejebc committed on
Commit
1903c5a
·
verified ·
1 Parent(s): d2dffba

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +23 -28
handler.py CHANGED
@@ -1,42 +1,37 @@
1
  from typing import Dict, List, Any
2
- from transformers import pipeline, AutoConfig, AutoModelForTokenClassification, AutoTokenizer, BertTokenizerFast
3
  import os
4
 
5
-
6
- class EndpointHandler():
7
  def __init__(self, path=""):
8
- dir_model = "urchade/gliner_multi-v2.1"
9
-
10
- config_path = os.path.join(path, "gliner_config.json")
11
- if not os.path.exists(config_path):
12
- raise FileNotFoundError(f"Custom configuration file not found at {config_path}")
13
-
14
- # Load the custom configuration
15
- config = AutoConfig.from_pretrained(config_path)
16
-
17
- # Load the model using the custom configuration
18
- self.model = AutoModelForTokenClassification.from_pretrained(dir_model, config=config)
19
-
20
- # Initialize the pipeline with the model and tokenizer
21
- # Use a pre-trained tokenizer compatible with your model
22
- self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
23
- # Use a pipeline appropriate for your task. Here we use "token-classification" for NER (Named Entity Recognition).
24
- self.pipeline = pipeline("token-classification", model=path, tokenizer=self.tokenizer)
25
 
26
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
  """
28
  Args:
29
  data (Dict[str, Any]): The input data including:
30
  - "inputs": The text input from which to extract information.
 
31
 
32
  Returns:
33
- List[Dict[str, Any]]: The extracted information from the text.
34
  """
35
- # Get inputs
36
  inputs = data.get("inputs", "")
37
-
38
- # Run the pipeline for text extraction
39
- extraction_results = self.pipeline(inputs)
40
-
41
- # Process and return the results as needed
42
- return extraction_results
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
+ from gliner import GLiNER
3
  import os
4
 
5
+ class EndpointHandler:
 
6
  def __init__(self, path=""):
7
+ # Initialize the GLiNER model
8
+ self.model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
11
  """
12
  Args:
13
  data (Dict[str, Any]): The input data including:
14
  - "inputs": The text input from which to extract information.
15
+ - "labels": The labels to predict entities for.
16
 
17
  Returns:
18
+ List[Dict[str, Any]]: The extracted entities from the text, formatted as required.
19
  """
20
+ # Get inputs and labels
21
  inputs = data.get("inputs", "")
22
+ labels = data.get("labels", [])
23
+
24
+ # Predict entities using GLiNER
25
+ entities = self.model.predict_entities(inputs, labels)
26
+
27
+ # Format the results to match the expected output structure
28
+ formatted_results = []
29
+ for entity in entities:
30
+ formatted_entity = {
31
+ "word": entity["text"],
32
+ "entity_group": entity["label"], # Assuming entity["label"] contains the label
33
+ "score": entity.get("score", 1.0) # Assuming a default score of 1.0 if not provided
34
+ }
35
+ formatted_results.append(formatted_entity)
36
+
37
+ return formatted_results