File size: 6,693 Bytes
55cb0d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import networkx as nx
import torch
import math
import re
import json
from typing import Dict, List, Any
class EndpointHandler:
    """Inference-endpoint handler that extracts (head, relation, tail) triples
    from free text with the Babelscape/rebel-large seq2seq model.

    Input text is split into overlapping token spans, each span is decoded
    into REBEL's linearized triplet format, the parsed triples are merged
    into a directed graph, and the graph is flattened back to a relation list.
    """

    def __init__(self, path: str = ""):
        """Load model/tokenizer and move the model to the best device.

        `path` is part of the Hugging Face inference-handler contract; the
        model here is pinned by hub name rather than loaded from `path`.
        """
        self.model_name = "Babelscape/rebel-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        # Compiled once: pattern1 strips special tokens, pattern2 isolates
        # the REBEL structure markers so they can be split on whitespace.
        self.pattern1 = re.compile('<pad>|<s>|</s>')
        self.pattern2 = re.compile('(<obj>|<subj>|<triplet>)')
        # Prefer GPU when available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Accepts either ``{"inputs": text-or-list}`` or a raw payload; returns
        ``{"results": [{"relations": [...]}, ...]}`` on success or
        ``{"error": msg}`` on failure.
        """
        try:
            # `get` (not `pop`): same value to the caller, but the incoming
            # request dict is no longer mutated as a side effect.
            inputs = data.get("inputs", data)
            if not isinstance(inputs, list):
                inputs = [inputs]
            results = []
            for text in inputs:
                graph = self.text_to_graph(text)
                results.append({"relations": self.graph_to_relations(graph)})
            return {"results": results}
        except Exception as e:
            # Endpoint boundary: report the failure instead of crashing.
            return {"error": str(e)}

    def text_to_graph(self, text: str, span_length: int = 128) -> "nx.DiGraph":
        """Convert `text` into a relation graph using the REBEL model.

        The token sequence is covered by `num_spans` windows of exactly
        `span_length` tokens; `overlap` shifts every window left just enough
        that the last window still ends inside the text, so all slices have
        equal length and can be stacked into one batch.
        """
        inputs = self.tokenizer([text], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        num_tokens = len(inputs["input_ids"][0])
        num_spans = math.ceil(num_tokens / span_length)
        # max(..., 1) guards the single-span case against division by zero.
        overlap = math.ceil((num_spans * span_length - num_tokens) /
                            max(num_spans - 1, 1))
        spans_boundaries = []
        start = 0
        for i in range(num_spans):
            spans_boundaries.append([start + span_length * i,
                                     start + span_length * (i + 1)])
            start -= overlap
        # Slice ids/masks per span and rebuild a batched input dict.
        tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                      for boundary in spans_boundaries]
        tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                        for boundary in spans_boundaries]
        inputs = {
            "input_ids": torch.stack(tensor_ids).to(self.device),
            "attention_mask": torch.stack(tensor_masks).to(self.device)
        }
        # Beam-search decode; 3 candidate linearizations per span.
        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 256,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences
        }
        with torch.no_grad():
            generated_tokens = self.model.generate(**inputs, **gen_kwargs)
        # Keep special tokens: the parser needs the <triplet>/<subj>/<obj> markers.
        decoded_preds = self.tokenizer.batch_decode(generated_tokens,
                                                    skip_special_tokens=False)
        graph = nx.DiGraph()
        for i, sentence_pred in enumerate(decoded_preds):
            # Predictions come back grouped: `num_return_sequences` per span.
            current_span_index = i // num_return_sequences
            relations = self.extract_relations_from_model_output(sentence_pred)
            for relation in relations:
                relation["meta"] = {"spans": [spans_boundaries[current_span_index]]}
                self.add_relation_to_graph(graph, relation)
        return graph

    def extract_relations_from_model_output(self, text: str) -> List[Dict[str, str]]:
        """Parse REBEL's linearized output into relation dicts.

        Expected format: ``<triplet> subject <subj> object <obj> relation``,
        where ``<subj>``/``<obj>`` pairs may repeat for a shared subject.
        Returns a list of ``{'head', 'type', 'tail'}`` dicts; incomplete
        triples are dropped.
        """
        relations: List[Dict[str, str]] = []
        subject, relation, object_ = '', '', ''

        def flush():
            # Record the accumulated triple, but only if it is complete.
            if subject and relation and object_:
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })

        current = None
        cleaned = self.pattern1.sub('', text.strip())
        # Raw string: `\g` is not a valid str escape, so the replacement must
        # be r' \g<1> ' to keep the backslash for the regex group reference.
        cleaned = self.pattern2.sub(r' \g<1> ', cleaned)
        for token in cleaned.split():
            if token == "<triplet>":
                # New subject begins: emit the previous triple, reset all fields.
                current = 'subj'
                flush()
                subject, relation, object_ = '', '', ''
            elif token == "<subj>":
                # New object for the same subject: emit, keep the subject.
                current = 'obj'
                flush()
                relation, object_ = '', ''
            elif token == "<obj>":
                current = 'rel'
            elif current == 'subj':
                subject += ' ' + token
            elif current == 'rel':
                relation += ' ' + token
            elif current == 'obj':
                object_ += ' ' + token
        flush()  # emit the trailing triple, if any
        return relations

    def add_relation_to_graph(self, graph: "nx.DiGraph", relation: Dict[str, Any]) -> None:
        """Insert `relation` into `graph`, merging span metadata on duplicates.

        Bug fix: the old test (`relation_type in graph[head][tail]`) checked
        membership against the edge attribute-dict *keys* ('relation'/'spans'),
        so the merge branch never ran and repeated relations overwrote the
        edge's spans. We now compare against the stored 'relation' attribute.
        """
        head, tail = relation['head'], relation['tail']
        relation_type = relation['type']
        spans = relation.get('meta', {}).get('spans', [])
        if graph.has_edge(head, tail) and graph[head][tail].get('relation') == relation_type:
            existing_spans = graph[head][tail]['spans']
            existing_spans.extend(s for s in spans if s not in existing_spans)
        else:
            # NOTE: DiGraph keeps one edge per (head, tail) pair, so a
            # *different* relation type between the same nodes replaces the
            # previous one — preserved from the original design.
            graph.add_edge(head, tail, relation=relation_type, spans=spans)

    def graph_to_relations(self, graph: "nx.DiGraph") -> List[Dict[str, str]]:
        """Flatten the graph back into a list of {'head','type','tail'} dicts."""
        relations = []
        for u, v, data in graph.edges(data=True):
            relations.append({
                "head": u,
                "type": data["relation"],
                "tail": v
            })
        return relations
|