File size: 6,693 Bytes
55cb0d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import networkx as nx
import torch
import math
import re
import json
from typing import Dict, List, Any

class EndpointHandler:
    """Inference handler that extracts relation triplets from raw text.

    Uses the Babelscape/rebel-large seq2seq model to generate linearized
    triplets of the form ``<triplet> head <subj> tail <obj> type``, parses
    them into ``{"head", "type", "tail"}`` dicts, and aggregates them in a
    directed graph keyed by entity surface forms.
    """

    def __init__(self, path: str = ""):
        """Load the REBEL model/tokenizer and move the model to GPU if available.

        Args:
            path: Unused; accepted for inference-endpoint framework compatibility.
        """
        self.model_name = "Babelscape/rebel-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

        # Pre-compiled patterns for cleaning decoded model output:
        # pattern1 strips framing special tokens; pattern2 captures the
        # structural tags so they can be padded with spaces and survive
        # str.split() as standalone tokens. Raw strings avoid invalid
        # escape-sequence warnings.
        self.pattern1 = re.compile(r'<pad>|<s>|</s>')
        self.pattern2 = re.compile(r'(<obj>|<subj>|<triplet>)')

        # Run on GPU when one is available; fall back to CPU otherwise.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an inference request.

        Args:
            data: Request payload. Relation extraction runs on
                ``data["inputs"]`` (a string or list of strings); if the key
                is absent the whole payload is treated as the input.

        Returns:
            ``{"results": [{"relations": [...]}, ...]}`` on success, or
            ``{"error": "<message>"}`` on failure.
        """
        try:
            # BUG FIX: use .get() instead of .pop() so the caller's payload
            # dict is not mutated as a side effect of handling the request.
            inputs = data.get("inputs", data)
            if not isinstance(inputs, list):
                inputs = [inputs]

            # Process each text input independently.
            results = []
            for text in inputs:
                graph = self.text_to_graph(text)
                relations = self.graph_to_relations(graph)
                results.append({"relations": relations})

            return {"results": results}

        except Exception as e:
            # Top-level serving boundary: report the failure to the client
            # rather than crashing the worker.
            return {"error": str(e)}

    def text_to_graph(self, text: str, span_length: int = 128) -> "nx.DiGraph":
        """Convert input text to a relation graph using the REBEL model.

        Long inputs are split into overlapping token spans of ``span_length``
        so every token is covered by at least one span.

        Args:
            text: The input document.
            span_length: Token-window size fed to the model per span.

        Returns:
            A directed graph whose edges carry ``relation`` and ``spans``
            attributes (spans record which token windows produced the edge).
        """
        inputs = self.tokenizer([text], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Distribute the tokens over num_spans windows; `overlap` is chosen
        # so the windows together still end within the tokenized input.
        num_tokens = len(inputs["input_ids"][0])
        num_spans = math.ceil(num_tokens / span_length)
        overlap = math.ceil((num_spans * span_length - num_tokens) /
                          max(num_spans - 1, 1))

        # Each boundary is [start, end); consecutive windows overlap by
        # `overlap` tokens because `start` is shifted back each iteration.
        spans_boundaries = []
        start = 0
        for i in range(num_spans):
            spans_boundaries.append([start + span_length * i,
                                  start + span_length * (i + 1)])
            start -= overlap

        # Slice ids/masks per window and batch them for a single generate call.
        tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                     for boundary in spans_boundaries]
        tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                       for boundary in spans_boundaries]

        inputs = {
            "input_ids": torch.stack(tensor_ids).to(self.device),
            "attention_mask": torch.stack(tensor_masks).to(self.device)
        }

        # Beam search with several returned sequences per span improves recall.
        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 256,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences
        }

        with torch.no_grad():
            generated_tokens = self.model.generate(**inputs, **gen_kwargs)

        # Keep special tokens: the parser relies on <triplet>/<subj>/<obj>.
        decoded_preds = self.tokenizer.batch_decode(generated_tokens,
                                                  skip_special_tokens=False)

        # Build the graph; predictions are grouped num_return_sequences per span.
        graph = nx.DiGraph()
        for i, sentence_pred in enumerate(decoded_preds):
            current_span_index = i // num_return_sequences
            relations = self.extract_relations_from_model_output(sentence_pred)
            for relation in relations:
                relation["meta"] = {"spans": [spans_boundaries[current_span_index]]}
                self.add_relation_to_graph(graph, relation)

        return graph

    def extract_relations_from_model_output(self, text: str) -> List[Dict[str, str]]:
        """Parse REBEL's linearized output into relation dicts.

        Output format is ``<triplet> head <subj> tail <obj> type`` (possibly
        repeated); a triplet is emitted whenever all three parts are present.

        Args:
            text: One decoded model prediction, special tokens included.

        Returns:
            A list of ``{"head", "type", "tail"}`` dicts (empty if none parse).
        """
        relations = []
        subject, relation, object_ = '', '', ''
        text = text.strip()
        current = None

        # Strip framing tokens, then pad structural tags with spaces so
        # split() yields them as standalone tokens.
        # BUG FIX: raw string for the \g<1> group reference — in a non-raw
        # string '\g' is an invalid escape sequence.
        text_replaced = self.pattern1.sub('', text)
        text_replaced = self.pattern2.sub(r' \g<1> ', text_replaced)

        for token in text_replaced.split():
            if token == "<triplet>":
                # New triplet begins: flush any complete one, then read the head.
                current = 'subj'
                if subject and relation and object_:
                    relations.append({
                        'head': subject.strip(),
                        'type': relation.strip(),
                        'tail': object_.strip()
                    })
                    subject, relation, object_ = '', '', ''
            elif token == "<subj>":
                # Tail text follows; the head may be shared by several tails,
                # so only the relation/tail accumulators are reset.
                current = 'obj'
                if subject and relation and object_:
                    relations.append({
                        'head': subject.strip(),
                        'type': relation.strip(),
                        'tail': object_.strip()
                    })
                    relation, object_ = '', ''
            elif token == "<obj>":
                # Relation-type text follows.
                current = 'rel'
            else:
                if current == 'subj':
                    subject += ' ' + token
                elif current == 'rel':
                    relation += ' ' + token
                elif current == 'obj':
                    object_ += ' ' + token

        # Flush the final triplet if it is complete.
        if subject and relation and object_:
            relations.append({
                'head': subject.strip(),
                'type': relation.strip(),
                'tail': object_.strip()
            })

        return relations

    def add_relation_to_graph(self, graph: "nx.DiGraph", relation: Dict[str, Any]) -> None:
        """Add a relation to the graph, merging span metadata for duplicates.

        Args:
            graph: Target graph; edges carry ``relation`` and ``spans`` attrs.
            relation: ``{"head", "type", "tail"}`` plus optional
                ``meta.spans`` (list of [start, end) token boundaries).
        """
        head, tail = relation['head'], relation['tail']
        relation_type = relation['type']
        spans = relation.get('meta', {}).get('spans', [])

        # BUG FIX: edge attributes are stored under the keys 'relation' and
        # 'spans'. The previous check `relation_type in graph[head][tail]`
        # looked for the relation name itself as an attribute key, so a
        # repeated relation never merged and instead overwrote the edge
        # (losing previously collected spans). Compare the stored relation
        # value instead.
        if graph.has_edge(head, tail) and graph[head][tail].get('relation') == relation_type:
            existing_spans = graph[head][tail]['spans']
            existing_spans.extend(s for s in spans if s not in existing_spans)
        else:
            # NOTE(review): DiGraph keeps a single edge per (head, tail), so a
            # different relation type between the same pair still replaces the
            # old edge — a MultiDiGraph would be needed to keep both; left
            # as-is to preserve the existing interface.
            graph.add_edge(head, tail, relation=relation_type, spans=spans)

    def graph_to_relations(self, graph: "nx.DiGraph") -> List[Dict[str, str]]:
        """Flatten a relation graph back into a list of relation dicts.

        Args:
            graph: Graph whose edges carry a ``relation`` attribute.

        Returns:
            One ``{"head", "type", "tail"}`` dict per edge.
        """
        relations = []
        for u, v, data in graph.edges(data=True):
            relations.append({
                "head": u,
                "type": data["relation"],
                "tail": v
            })
        return relations