import json
import math
import re
from typing import Any, Dict, List

import networkx as nx
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load model and tokenizer during initialization
        self.model_name = "Babelscape/rebel-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Compile regex patterns for cleaning the model's linearized output:
        # pattern1 strips the tokenizer's special tokens, pattern2 space-pads
        # REBEL's triplet markers so str.split() isolates them.
        self.pattern1 = re.compile('<s>|<pad>|</s>')
        self.pattern2 = re.compile('(<triplet>|<subj>|<obj>)')

        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handler method for processing incoming requests.
        """
        try:
            # Extract text from the input data
            inputs = data.pop("inputs", data)
            if not isinstance(inputs, list):
                inputs = [inputs]

            # Process each text input
            results = []
            for text in inputs:
                graph = self.text_to_graph(text)
                relations = self.graph_to_relations(graph)
                results.append({"relations": relations})

            return {"results": results}
        except Exception as e:
            return {"error": str(e)}

    def text_to_graph(self, text: str, span_length: int = 128) -> nx.DiGraph:
        """
        Convert input text to a graph representation using the REBEL model.
""" inputs = self.tokenizer([text], return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} num_tokens = len(inputs["input_ids"][0]) num_spans = math.ceil(num_tokens / span_length) overlap = math.ceil((num_spans * span_length - num_tokens) / max(num_spans - 1, 1)) # Calculate span boundaries spans_boundaries = [] start = 0 for i in range(num_spans): spans_boundaries.append([start + span_length * i, start + span_length * (i + 1)]) start -= overlap # Process each span tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries] tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries] inputs = { "input_ids": torch.stack(tensor_ids).to(self.device), "attention_mask": torch.stack(tensor_masks).to(self.device) } # Generate predictions num_return_sequences = 3 gen_kwargs = { "max_length": 256, "length_penalty": 0, "num_beams": 3, "num_return_sequences": num_return_sequences } with torch.no_grad(): generated_tokens = self.model.generate(**inputs, **gen_kwargs) decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False) # Build graph from predictions graph = nx.DiGraph() for i, sentence_pred in enumerate(decoded_preds): current_span_index = i // num_return_sequences relations = self.extract_relations_from_model_output(sentence_pred) for relation in relations: relation["meta"] = {"spans": [spans_boundaries[current_span_index]]} self.add_relation_to_graph(graph, relation) return graph def extract_relations_from_model_output(self, text: str) -> List[Dict[str, str]]: """ Extract relations from the model's output text. 
""" relations = [] subject, relation, object_ = '', '', '' text = text.strip() current = None text_replaced = self.pattern1.sub('', text) text_replaced = self.pattern2.sub(' \g<1> ', text_replaced) for token in text_replaced.split(): if token == "": current = 'subj' if subject and relation and object_: relations.append({ 'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip() }) subject, relation, object_ = '', '', '' elif token == "": current = 'obj' if subject and relation and object_: relations.append({ 'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip() }) relation, object_ = '', '' elif token == "": current = 'rel' else: if current == 'subj': subject += ' ' + token elif current == 'rel': relation += ' ' + token elif current == 'obj': object_ += ' ' + token if subject and relation and object_: relations.append({ 'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip() }) return relations def add_relation_to_graph(self, graph: nx.DiGraph, relation: Dict[str, Any]) -> None: """ Add a relation to the graph. """ head, tail = relation['head'], relation['tail'] relation_type = relation['type'] span = relation.get('meta', {}).get('spans', []) if graph.has_edge(head, tail) and relation_type in graph[head][tail]: existing_spans = graph[head][tail][relation_type]['spans'] new_spans = [s for s in span if s not in existing_spans] graph[head][tail][relation_type]['spans'].extend(new_spans) else: graph.add_edge(head, tail, relation=relation_type, spans=span) def graph_to_relations(self, graph: nx.DiGraph) -> List[Dict[str, str]]: """ Convert a NetworkX graph to a list of relations. """ relations = [] for u, v, data in graph.edges(data=True): relations.append({ "head": u, "type": data["relation"], "tail": v }) return relations