ethanthoma committed on
Commit
55cb0d8
·
1 Parent(s): 6a7dba3

add custom pipeline

Browse files
__pycache__/handler.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
handler.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import networkx as nx
3
+ import torch
4
+ import math
5
+ import re
6
+ import json
7
+ from typing import Dict, List, Any
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path: str = ""):
11
+ # Load model and tokenizer during initialization
12
+ self.model_name = "Babelscape/rebel-large"
13
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
14
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
15
+
16
+ # Compile regex patterns
17
+ self.pattern1 = re.compile('<pad>|<s>|</s>')
18
+ self.pattern2 = re.compile('(<obj>|<subj>|<triplet>)')
19
+
20
+ # Set device
21
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ self.model = self.model.to(self.device)
23
+
24
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
25
+ """
26
+ Handler method for processing incoming requests.
27
+ """
28
+ try:
29
+ # Extract text from the input data
30
+ inputs = data.pop("inputs", data)
31
+ if not isinstance(inputs, list):
32
+ inputs = [inputs]
33
+
34
+ # Process each text input
35
+ results = []
36
+ for text in inputs:
37
+ graph = self.text_to_graph(text)
38
+ relations = self.graph_to_relations(graph)
39
+ results.append({"relations": relations})
40
+
41
+ return {"results": results}
42
+
43
+ except Exception as e:
44
+ return {"error": str(e)}
45
+
46
+ def text_to_graph(self, text: str, span_length: int = 128) -> nx.DiGraph:
47
+ """
48
+ Convert input text to a graph representation using the REBEL model.
49
+ """
50
+ inputs = self.tokenizer([text], return_tensors="pt")
51
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
52
+
53
+ num_tokens = len(inputs["input_ids"][0])
54
+ num_spans = math.ceil(num_tokens / span_length)
55
+ overlap = math.ceil((num_spans * span_length - num_tokens) /
56
+ max(num_spans - 1, 1))
57
+
58
+ # Calculate span boundaries
59
+ spans_boundaries = []
60
+ start = 0
61
+ for i in range(num_spans):
62
+ spans_boundaries.append([start + span_length * i,
63
+ start + span_length * (i + 1)])
64
+ start -= overlap
65
+
66
+ # Process each span
67
+ tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
68
+ for boundary in spans_boundaries]
69
+ tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
70
+ for boundary in spans_boundaries]
71
+
72
+ inputs = {
73
+ "input_ids": torch.stack(tensor_ids).to(self.device),
74
+ "attention_mask": torch.stack(tensor_masks).to(self.device)
75
+ }
76
+
77
+ # Generate predictions
78
+ num_return_sequences = 3
79
+ gen_kwargs = {
80
+ "max_length": 256,
81
+ "length_penalty": 0,
82
+ "num_beams": 3,
83
+ "num_return_sequences": num_return_sequences
84
+ }
85
+
86
+ with torch.no_grad():
87
+ generated_tokens = self.model.generate(**inputs, **gen_kwargs)
88
+
89
+ decoded_preds = self.tokenizer.batch_decode(generated_tokens,
90
+ skip_special_tokens=False)
91
+
92
+ # Build graph from predictions
93
+ graph = nx.DiGraph()
94
+ for i, sentence_pred in enumerate(decoded_preds):
95
+ current_span_index = i // num_return_sequences
96
+ relations = self.extract_relations_from_model_output(sentence_pred)
97
+ for relation in relations:
98
+ relation["meta"] = {"spans": [spans_boundaries[current_span_index]]}
99
+ self.add_relation_to_graph(graph, relation)
100
+
101
+ return graph
102
+
103
+ def extract_relations_from_model_output(self, text: str) -> List[Dict[str, str]]:
104
+ """
105
+ Extract relations from the model's output text.
106
+ """
107
+ relations = []
108
+ subject, relation, object_ = '', '', ''
109
+ text = text.strip()
110
+ current = None
111
+
112
+ text_replaced = self.pattern1.sub('', text)
113
+ text_replaced = self.pattern2.sub(' \g<1> ', text_replaced)
114
+
115
+ for token in text_replaced.split():
116
+ if token == "<triplet>":
117
+ current = 'subj'
118
+ if subject and relation and object_:
119
+ relations.append({
120
+ 'head': subject.strip(),
121
+ 'type': relation.strip(),
122
+ 'tail': object_.strip()
123
+ })
124
+ subject, relation, object_ = '', '', ''
125
+ elif token == "<subj>":
126
+ current = 'obj'
127
+ if subject and relation and object_:
128
+ relations.append({
129
+ 'head': subject.strip(),
130
+ 'type': relation.strip(),
131
+ 'tail': object_.strip()
132
+ })
133
+ relation, object_ = '', ''
134
+ elif token == "<obj>":
135
+ current = 'rel'
136
+ else:
137
+ if current == 'subj':
138
+ subject += ' ' + token
139
+ elif current == 'rel':
140
+ relation += ' ' + token
141
+ elif current == 'obj':
142
+ object_ += ' ' + token
143
+
144
+ if subject and relation and object_:
145
+ relations.append({
146
+ 'head': subject.strip(),
147
+ 'type': relation.strip(),
148
+ 'tail': object_.strip()
149
+ })
150
+
151
+ return relations
152
+
153
+ def add_relation_to_graph(self, graph: nx.DiGraph, relation: Dict[str, Any]) -> None:
154
+ """
155
+ Add a relation to the graph.
156
+ """
157
+ head, tail = relation['head'], relation['tail']
158
+ relation_type = relation['type']
159
+ span = relation.get('meta', {}).get('spans', [])
160
+
161
+ if graph.has_edge(head, tail) and relation_type in graph[head][tail]:
162
+ existing_spans = graph[head][tail][relation_type]['spans']
163
+ new_spans = [s for s in span if s not in existing_spans]
164
+ graph[head][tail][relation_type]['spans'].extend(new_spans)
165
+ else:
166
+ graph.add_edge(head, tail, relation=relation_type, spans=span)
167
+
168
+ def graph_to_relations(self, graph: nx.DiGraph) -> List[Dict[str, str]]:
169
+ """
170
+ Convert a NetworkX graph to a list of relations.
171
+ """
172
+ relations = []
173
+ for u, v, data in graph.edges(data=True):
174
+ relations.append({
175
+ "head": u,
176
+ "type": data["relation"],
177
+ "tail": v
178
+ })
179
+ return relations
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.30.0
2
+ torch>=2.0.0
3
+ networkx>=3.1
4
+ holidays>=0.25
5
+ numpy>=1.24.0
6
+ regex>=2023.0.0
7
+ h5py>=3.8.0
8
+ pandas>=2.0.0