| import torch |
| import subprocess |
| import sys |
|
|
class EndpointHandler:
    """Inference handler that serves a GLiNER entity-extraction model.

    The model is loaded once in ``__init__``; each call receives a request
    payload and returns either ``{"entities": ...}`` or ``{"error": ...}``.
    """

    def __init__(self, path=""):
        """Load the GLiNER model at *path*, placing it on GPU when available.

        Args:
            path: Model location forwarded to ``GLiNER.from_pretrained``.

        Raises:
            subprocess.CalledProcessError: if ``gliner`` is not installed and
                the fallback ``pip install`` fails.
            ImportError: if ``gliner`` still cannot be imported afterwards.
        """
        import importlib

        try:
            import gliner  # noqa: F401  (availability probe only)
        except ImportError:
            # NOTE(review): installing at runtime adds cold-start latency and
            # pulls code from the network; prefer pinning gliner==0.2.8 in the
            # deployment's requirements instead.
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "gliner==0.2.8"]
            )
            # The import system caches directory listings; without this a
            # package installed after interpreter start may not be found.
            importlib.invalidate_caches()

        from gliner import GLiNER

        # Fall back to CPU instead of crashing on hosts without a GPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = GLiNER.from_pretrained(path).to(self.device)
        self.model.eval()

    def __call__(self, data):
        """Run entity extraction on one request payload.

        Accepted payload shapes:

        * ``{"text": str, "labels": [...]}``
        * ``{"inputs": {"text": str, "labels": [...]}}``
        * ``{"inputs": str, "parameters": {"labels": [...]}}``

        ``labels`` may also be given as one comma-separated string, e.g.
        ``"person, organization"``.

        Returns:
            ``{"entities": <result of model.predict_entities>}`` on success, or
            ``{"error": "Please provide 'text' and 'labels'"}`` when the text
            or labels are missing, empty, or the payload has another shape.
        """
        params = {}
        if isinstance(data, dict) and "inputs" in data:
            params = data.get("parameters") or {}
            data = data["inputs"]

        if isinstance(data, str):
            # Raw text was sent as "inputs"; labels travel in "parameters".
            text = data
            labels = params.get("labels", []) if isinstance(params, dict) else []
        elif isinstance(data, dict):
            text = data.get("text", "")
            labels = data.get("labels", [])
        else:
            # Unsupported payload shape: report it instead of raising.
            text, labels = "", []

        if isinstance(labels, str):
            # Normalise a comma-separated string into a list of label names.
            labels = [part.strip() for part in labels.split(",") if part.strip()]

        if not text or not labels:
            return {"error": "Please provide 'text' and 'labels'"}

        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            entities = self.model.predict_entities(text, labels)
        return {"entities": entities}
|
|