|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import networkx as nx |
|
|
import torch |
|
|
import math |
|
|
import re |
|
|
import json |
|
|
from typing import Dict, List, Any |
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that extracts (head, type, tail) relation
    triplets from free text with the Babelscape REBEL seq2seq model.

    Long inputs are split into overlapping token windows, each window is
    decoded with beam search, and the parsed triplets are merged through a
    ``networkx.DiGraph`` so duplicate relations collapse into a single edge.
    """

    def __init__(self, path: str = ""):
        """Load the REBEL tokenizer/model and move the model to GPU if available.

        Args:
            path: Model directory supplied by the serving runtime.  Currently
                ignored; the handler always pulls ``Babelscape/rebel-large``
                from the Hub.
        """
        self.model_name = "Babelscape/rebel-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # pattern1 strips the sequence-framing special tokens from decoded
        # output; pattern2 pads REBEL's structural markers with spaces so a
        # plain str.split() yields them as standalone tokens.
        self.pattern1 = re.compile('<pad>|<s>|</s>')
        self.pattern2 = re.compile('(<obj>|<subj>|<triplet>)')

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload.  ``data["inputs"]`` may be a single string
                or a list of strings; if the key is absent the whole payload
                is treated as the input.

        Returns:
            ``{"results": [{"relations": [...]}, ...]}`` on success, or
            ``{"error": "<message>"}`` on failure.  The endpoint must never
            raise, so all exceptions are reported in-band.
        """
        try:
            inputs = data.pop("inputs", data)
            if not isinstance(inputs, list):
                inputs = [inputs]

            results = []
            for text in inputs:
                graph = self.text_to_graph(text)
                results.append({"relations": self.graph_to_relations(graph)})

            return {"results": results}

        except Exception as e:  # top-level boundary: report errors in-band
            return {"error": str(e)}

    def text_to_graph(self, text: str, span_length: int = 128) -> "nx.DiGraph":
        """Run REBEL over *text* and collect the extracted relations in a graph.

        The tokenized input is cut into ``num_spans`` windows of
        ``span_length`` tokens that overlap just enough to cover the whole
        sequence, so relations near a window boundary are not lost.

        Args:
            text: Input document.
            span_length: Window size in tokens.

        Returns:
            A directed graph with one edge per (head, tail) entity pair,
            carrying ``relation`` and ``spans`` attributes.
        """
        inputs = self.tokenizer([text], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        num_tokens = len(inputs["input_ids"][0])
        num_spans = math.ceil(num_tokens / span_length)
        # Surplus window capacity, spread evenly as overlap between
        # consecutive windows (max(..., 1) guards the single-window case).
        overlap = math.ceil((num_spans * span_length - num_tokens) /
                            max(num_spans - 1, 1))

        spans_boundaries = []
        start = 0
        for i in range(num_spans):
            spans_boundaries.append([start + span_length * i,
                                     start + span_length * (i + 1)])
            start -= overlap

        tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                      for boundary in spans_boundaries]
        tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                        for boundary in spans_boundaries]

        inputs = {
            "input_ids": torch.stack(tensor_ids).to(self.device),
            "attention_mask": torch.stack(tensor_masks).to(self.device)
        }

        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 256,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences
        }

        with torch.no_grad():
            generated_tokens = self.model.generate(**inputs, **gen_kwargs)

        # Keep special tokens: the <triplet>/<subj>/<obj> markers are needed
        # to parse the linearized triplets afterwards.
        decoded_preds = self.tokenizer.batch_decode(generated_tokens,
                                                    skip_special_tokens=False)

        graph = nx.DiGraph()
        for i, sentence_pred in enumerate(decoded_preds):
            # generate() returns num_return_sequences outputs per window.
            current_span_index = i // num_return_sequences
            relations = self.extract_relations_from_model_output(sentence_pred)
            for relation in relations:
                relation["meta"] = {"spans": [spans_boundaries[current_span_index]]}
                self.add_relation_to_graph(graph, relation)

        return graph

    def extract_relations_from_model_output(self, text: str) -> List[Dict[str, str]]:
        """Parse REBEL's linearized output into relation dicts.

        The decoder emits groups shaped like
        ``<triplet> head <subj> tail <obj> relation`` (an additional
        ``<subj>`` may follow to reuse the same head for another triplet).

        Args:
            text: One decoded sequence, special tokens included.

        Returns:
            A list of ``{"head", "type", "tail"}`` dicts, whitespace-stripped.
        """
        relations: List[Dict[str, str]] = []
        subject, relation, object_ = '', '', ''
        current = None

        def flush() -> None:
            # Emit the triplet accumulated so far, but only if complete.
            if subject and relation and object_:
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })

        text_replaced = self.pattern1.sub('', text.strip())
        # Raw string: '\g' in a plain literal is an invalid escape sequence.
        text_replaced = self.pattern2.sub(r' \g<1> ', text_replaced)

        for token in text_replaced.split():
            if token == "<triplet>":
                flush()
                current = 'subj'
                subject, relation, object_ = '', '', ''
            elif token == "<subj>":
                flush()
                current = 'obj'
                relation, object_ = '', ''
            elif token == "<obj>":
                current = 'rel'
            elif current == 'subj':
                subject += ' ' + token
            elif current == 'rel':
                relation += ' ' + token
            elif current == 'obj':
                object_ += ' ' + token

        flush()  # last triplet has no trailing marker
        return relations

    def add_relation_to_graph(self, graph: "nx.DiGraph", relation: Dict[str, Any]) -> None:
        """Insert *relation* as an edge ``head -> tail``.

        A ``DiGraph`` stores at most one edge per (head, tail) pair.  When the
        edge already exists with the same relation type, the new source spans
        are merged into its ``spans`` attribute instead of re-adding the edge;
        a different relation type replaces the previous edge data (last wins).
        """
        head, tail = relation['head'], relation['tail']
        relation_type = relation['type']
        span = relation.get('meta', {}).get('spans', [])

        # Bug fix: the edge data dict is {'relation': ..., 'spans': ...}, so
        # the duplicate check must compare the stored 'relation' value; the
        # old `relation_type in graph[head][tail]` tested attribute *keys*
        # and never matched, leaving the merge branch dead.
        if graph.has_edge(head, tail) and graph[head][tail].get('relation') == relation_type:
            existing_spans = graph[head][tail]['spans']
            existing_spans.extend(s for s in span if s not in existing_spans)
        else:
            graph.add_edge(head, tail, relation=relation_type, spans=span)

    def graph_to_relations(self, graph: "nx.DiGraph") -> List[Dict[str, str]]:
        """Flatten the graph back into a list of ``{"head", "type", "tail"}`` dicts."""
        return [
            {"head": u, "type": data["relation"], "tail": v}
            for u, v, data in graph.edges(data=True)
        ]
|
|
|