# rebel-large / handler.py
# Author: ethanthoma — "add custom pipeline" (commit 55cb0d8)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import networkx as nx
import torch
import math
import re
import json
from typing import Dict, List, Any
class EndpointHandler:
    """Custom Inference Endpoint handler for Babelscape/rebel-large.

    Runs the REBEL seq2seq relation-extraction model over incoming text,
    accumulates the predicted triplets in a directed graph (deduplicating
    repeats across overlapping spans), and returns them as a flat list of
    ``{"head", "type", "tail"}`` relations.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model, compile parsing regexes, pick a device.

        Args:
            path: Local repository path supplied by the endpoint runtime.
                NOTE(review): currently ignored — the model is always pulled
                from the Hub by name; confirm this is intentional.
        """
        self.model_name = "Babelscape/rebel-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        # Raw strings: '\g' in a non-raw literal is an invalid escape
        # sequence (SyntaxWarning on recent Pythons).
        self.pattern1 = re.compile(r'<pad>|<s>|</s>')            # special tokens to drop
        self.pattern2 = re.compile(r'(<obj>|<subj>|<triplet>)')  # markers to isolate
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout etc.

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. Text is read from ``data["inputs"]`` — a
                single string or a list of strings; if the key is absent the
                whole payload is treated as the input.

        Returns:
            ``{"results": [{"relations": [...]}, ...]}`` on success, or
            ``{"error": message}`` on failure (endpoint convention: the
            handler reports errors rather than crashing the worker).
        """
        try:
            inputs = data.pop("inputs", data)
            if not isinstance(inputs, list):
                inputs = [inputs]
            results = []
            for text in inputs:
                graph = self.text_to_graph(text)
                relations = self.graph_to_relations(graph)
                results.append({"relations": relations})
            return {"results": results}
        except Exception as e:
            # Surface the failure to the client as a structured response.
            return {"error": str(e)}

    def text_to_graph(self, text: str, span_length: int = 128) -> nx.DiGraph:
        """Run REBEL over *text* and collect the predicted relations in a graph.

        Long inputs are split into overlapping token spans of *span_length*
        so the whole text fits the model's context window.

        Args:
            text: Raw input text.
            span_length: Tokens per span fed to the model.

        Returns:
            A ``networkx.DiGraph`` whose edges carry ``relation`` (label)
            and ``spans`` (token-span provenance) attributes.
        """
        inputs = self.tokenizer([text], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        num_tokens = len(inputs["input_ids"][0])
        num_spans = math.ceil(num_tokens / span_length)
        # Distribute the leftover capacity evenly as overlap between spans.
        overlap = math.ceil((num_spans * span_length - num_tokens) /
                            max(num_spans - 1, 1))

        # Each span begins `overlap` tokens before the previous one ends.
        spans_boundaries = []
        start = 0
        for i in range(num_spans):
            spans_boundaries.append([start + span_length * i,
                                     start + span_length * (i + 1)])
            start -= overlap

        tensor_ids = [inputs["input_ids"][0][lo:hi]
                      for lo, hi in spans_boundaries]
        tensor_masks = [inputs["attention_mask"][0][lo:hi]
                        for lo, hi in spans_boundaries]
        inputs = {
            "input_ids": torch.stack(tensor_ids).to(self.device),
            "attention_mask": torch.stack(tensor_masks).to(self.device),
        }

        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 256,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences,
        }
        with torch.no_grad():
            generated_tokens = self.model.generate(**inputs, **gen_kwargs)
        # Keep special tokens: the <triplet>/<subj>/<obj> markers are what
        # extract_relations_from_model_output parses.
        decoded_preds = self.tokenizer.batch_decode(generated_tokens,
                                                    skip_special_tokens=False)

        graph = nx.DiGraph()
        for i, sentence_pred in enumerate(decoded_preds):
            # generate() yields num_return_sequences outputs per input span.
            current_span_index = i // num_return_sequences
            for relation in self.extract_relations_from_model_output(sentence_pred):
                relation["meta"] = {"spans": [spans_boundaries[current_span_index]]}
                self.add_relation_to_graph(graph, relation)
        return graph

    def extract_relations_from_model_output(self, text: str) -> List[Dict[str, str]]:
        """Parse REBEL's linearized output into relation dicts.

        The model emits sequences of the form
        ``<triplet> head <subj> tail <obj> relation-type`` where a single
        ``<triplet>`` head may be followed by several ``<subj> ... <obj> ...``
        pairs sharing that head.

        Returns:
            A list of ``{'head', 'type', 'tail'}`` dicts, in output order.
        """
        relations: List[Dict[str, str]] = []
        subject, relation, object_ = '', '', ''

        def flush() -> None:
            # Emit the triplet currently being built, if it is complete.
            if subject and relation and object_:
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })

        text = text.strip()
        current = None
        text_replaced = self.pattern1.sub('', text)                  # drop <s>/<pad>/</s>
        text_replaced = self.pattern2.sub(r' \g<1> ', text_replaced)  # isolate markers
        for token in text_replaced.split():
            if token == "<triplet>":
                current = 'subj'
                flush()
                subject, relation, object_ = '', '', ''
            elif token == "<subj>":
                current = 'obj'
                flush()
                # Keep the subject: one head can carry several triplets.
                relation, object_ = '', ''
            elif token == "<obj>":
                current = 'rel'
            elif current == 'subj':
                subject += ' ' + token
            elif current == 'rel':
                relation += ' ' + token
            elif current == 'obj':
                object_ += ' ' + token
        flush()  # final triplet has no trailing marker to trigger it
        return relations

    def add_relation_to_graph(self, graph: nx.DiGraph, relation: Dict[str, Any]) -> None:
        """Insert *relation* into *graph*, merging span metadata for duplicates.

        Bug fix: the original tested ``relation_type in graph[head][tail]``,
        i.e. looked the relation *label* up among the edge-attribute keys
        ('relation', 'spans') — never true, so span provenance was never
        merged. Compare against the stored 'relation' attribute instead.
        """
        head, tail = relation['head'], relation['tail']
        relation_type = relation['type']
        spans = relation.get('meta', {}).get('spans', [])
        if graph.has_edge(head, tail) and graph[head][tail].get('relation') == relation_type:
            # Same edge, same relation label: accumulate new span provenance.
            existing_spans = graph[head][tail]['spans']
            existing_spans.extend(s for s in spans if s not in existing_spans)
        else:
            # NOTE: DiGraph keeps one edge per (head, tail); a different
            # relation type between the same pair replaces the previous one.
            graph.add_edge(head, tail, relation=relation_type, spans=spans)

    def graph_to_relations(self, graph: nx.DiGraph) -> List[Dict[str, str]]:
        """Flatten *graph* back into a list of {head, type, tail} dicts."""
        return [
            {"head": u, "type": data["relation"], "tail": v}
            for u, v, data in graph.edges(data=True)
        ]