|
|
from gliner import GLiNER |
|
|
import torch |
|
|
|
|
|
class EndpointHandler:
    """Serve a GLiNER named-entity-recognition model through a callable handler.

    The ``__init__(path)`` / ``__call__(data)`` shape with an ``"inputs"``
    envelope appears to follow the Hugging Face Inference Endpoints
    custom-handler convention (NOTE(review): confirm against the deployment).
    """

    def __init__(self, path=""):
        """Load the GLiNER model from *path* and put it in evaluation mode.

        Args:
            path: Model directory or hub identifier forwarded unchanged to
                ``GLiNER.from_pretrained``.
        """
        # Fall back to CPU so the handler can still start on hosts without a
        # GPU; a hard-coded "cuda" raises at load time there.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = GLiNER.from_pretrained(path).to(self.device)
        self.model.eval()

    def __call__(self, data):
        """Extract entities of the requested labels from a piece of text.

        Accepted payloads:

        * ``{"text": str, "labels": [...], "threshold": number?}``
        * the same dict wrapped as ``{"inputs": {...}}``
        * ``{"inputs": str, "parameters": {"labels": [...], "threshold": number?}}``

        When ``"inputs"`` is present, a dict-valued ``"parameters"`` entry
        supplies defaults that the inputs themselves override. ``labels`` may
        also be given as a comma-separated string such as ``"person, org"``.

        Args:
            data: Request payload in one of the shapes above.

        Returns:
            ``{"entities": ...}`` holding whatever ``predict_entities``
            returns, or ``{"error": ...}`` when the payload is not a mapping,
            the text or labels are missing/empty, or ``threshold`` is not
            numeric.
        """
        payload = data
        if isinstance(data, dict) and "inputs" in data:
            params = data.get("parameters")
            if not isinstance(params, dict):
                params = {}
            inputs = data["inputs"]
            if isinstance(inputs, str):
                # Plain-string inputs: the options live under "parameters".
                payload = {**params, "text": inputs}
            elif isinstance(inputs, dict):
                payload = {**params, **inputs}
            else:
                payload = inputs

        # Guard against raw strings/lists, which have no .get().
        if not isinstance(payload, dict):
            return {"error": "Please provide 'text' and 'labels'"}

        text = payload.get("text", "")
        labels = payload.get("labels", [])
        if isinstance(labels, str):
            # A bare string would otherwise be iterated character by character.
            labels = [label.strip() for label in labels.split(",") if label.strip()]

        if not text or not labels:
            return {"error": "Please provide 'text' and 'labels'"}

        # Only forward threshold when supplied, so the model's own default
        # applies otherwise.
        predict_kwargs = {}
        threshold = payload.get("threshold")
        if threshold is not None:
            try:
                predict_kwargs["threshold"] = float(threshold)
            except (TypeError, ValueError):
                return {"error": "'threshold' must be a number"}

        entities = self.model.predict_entities(text, labels, **predict_kwargs)
        return {"entities": entities}
|
|
|