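"""
Flask backend for a lightweight text-annotation tool: spaCy tokenization
endpoints, GLiNER-based entity and relationship suggestion, and Wikidata
lookups for entity linking. The single-page UI is served from
templates/index.html.
"""
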
import os

import requests
import spacy
from flask import Flask, render_template, request, jsonify
from gliner import GLiNER

app = Flask(__name__)

# Load a blank English spaCy pipeline for tokenization
nlp = spacy.blank("en")

# GLiNER spaCy pipeline, rebuilt only when the requested labels change
gliner_nlp = None
gliner_labels_cache = None

# GLiNER multitask model for relationship extraction (loaded on first use)
gliner_multitask = None

def get_or_create_multitask_model():
    """
    Get or create GLiNER multitask model for relationship extraction
    """
    global gliner_multitask
    
    if gliner_multitask is None:
        try:
            gliner_multitask = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")
        except Exception as e:
            print(f"Error loading GLiNER multitask model: {e}")
            return None
    
    return gliner_multitask

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/tokenize', methods=['POST'])
def tokenize_text():
    """
    Tokenize the input text and return token boundaries
    """
    data = request.get_json()
    text = data.get('text', '')
    
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    
    # Process text with spaCy
    doc = nlp(text)
    
    # Extract token information
    tokens = []
    for token in doc:
        tokens.append({
            'text': token.text,
            'start': token.idx,
            'end': token.idx + len(token.text)
        })
    
    return jsonify({
        'tokens': tokens,
        'text': text
    })
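
# Example request/response for /tokenize (a sketch; offsets follow the
# spacy.blank("en") tokenizer used above):
#   POST /tokenize  {"text": "Ada lived in London."}
#   -> {"tokens": [{"text": "Ada", "start": 0, "end": 3},
#                  {"text": "lived", "start": 4, "end": 9}, ...],
#       "text": "Ada lived in London."}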

@app.route('/find_token_boundaries', methods=['POST'])
def find_token_boundaries():
    """
    Given a text selection, find the token boundaries that encompass it
    """
    data = request.get_json()
    text = data.get('text', '')
    start = data.get('start', 0)
    end = data.get('end', 0)
    label = data.get('label', 'UNLABELED')
    
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    
    # Process text with spaCy
    doc = nlp(text)
    
    # Find tokens that overlap with the selection
    token_start = None
    token_end = None
    
    for token in doc:
        # Check if token overlaps with selection
        if token.idx < end and token.idx + len(token.text) > start:
            if token_start is None:
                token_start = token.idx
            token_end = token.idx + len(token.text)
    
    # If no tokens found, return original boundaries
    if token_start is None:
        token_start = start
        token_end = end
    
    return jsonify({
        'start': token_start,
        'end': token_end,
        'selected_text': text[token_start:token_end],
        'label': label
    })
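
# Example: a selection that cuts through tokens is snapped outward to the
# enclosing token boundaries (a sketch of the handler above):
#   POST /find_token_boundaries
#   {"text": "Ada Lovelace", "start": 1, "end": 6, "label": "PERSON"}
#   -> {"start": 0, "end": 12, "selected_text": "Ada Lovelace", "label": "PERSON"}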

@app.route('/get_default_labels', methods=['GET'])
def get_default_labels():
    """
    Return the default annotation labels with their colors
    """
    default_labels = [
        {'name': 'PERSON', 'color': '#fef3c7', 'border': '#f59e0b'},
        {'name': 'LOCATION', 'color': '#dbeafe', 'border': '#3b82f6'},
        {'name': 'ORGANIZATION', 'color': '#dcfce7', 'border': '#10b981'}
    ]
    
    return jsonify({'labels': default_labels})

@app.route('/get_default_relationship_labels', methods=['GET'])
def get_default_relationship_labels():
    """
    Return the default relationship labels with their colors
    """
    default_relationship_labels = [
        {'name': 'worked at', 'color': '#fce7f3', 'border': '#ec4899'},
        {'name': 'visited', 'color': '#f3e8ff', 'border': '#a855f7'}
    ]
    
    return jsonify({'relationship_labels': default_relationship_labels})

def get_or_create_gliner_pipeline(labels):
    """
    Get or create a GLiNER spaCy pipeline for the specified labels,
    reusing the cached pipeline when the labels have not changed
    """
    global gliner_nlp, gliner_labels_cache
    
    # Convert labels to lowercase for GLiNER
    gliner_labels = [label.lower() for label in labels]
    
    # Reuse the existing pipeline if the requested labels are unchanged
    if gliner_nlp is not None and gliner_labels_cache == gliner_labels:
        return gliner_nlp
    
    try:
        custom_spacy_config = {
            "gliner_model": "gliner-community/gliner_small-v2.5",
            "chunk_size": 250,
            "labels": gliner_labels,
            "style": "ent"
        }
        
        gliner_nlp = spacy.blank("en")
        gliner_nlp.add_pipe("gliner_spacy", config=custom_spacy_config)
        gliner_labels_cache = gliner_labels
        
        return gliner_nlp
    except Exception as e:
        print(f"Error creating GLiNER pipeline: {e}")
        gliner_nlp = None
        gliner_labels_cache = None
        return None

@app.route('/run_gliner', methods=['POST'])
def run_gliner():
    """
    Run GLiNER entity extraction on the provided text with specified labels
    """
    data = request.get_json()
    text = data.get('text', '')
    labels = data.get('labels', [])
    
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    
    if not labels:
        return jsonify({'error': 'No labels provided'}), 400
    
    try:
        # Get or create GLiNER pipeline
        pipeline = get_or_create_gliner_pipeline(labels)
        
        if pipeline is None:
            return jsonify({'error': 'Failed to initialize GLiNER pipeline'}), 500
        
        # Process text with GLiNER
        doc = pipeline(text)
        
        # Extract entities with token boundaries
        entities = []
        for ent in doc.ents:
            # Map GLiNER label back to user's label format
            original_label = None
            for label in labels:
                if label.lower() == ent.label_.lower():
                    original_label = label
                    break
            
            if original_label:
                entities.append({
                    'text': ent.text,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'label': original_label,
                    # gliner-spacy exposes the prediction confidence via the
                    # ent._.score span extension; fall back to 1.0 if absent
                    'confidence': getattr(ent._, 'score', 1.0)
                })
        
        return jsonify({
            'entities': entities,
            'total_found': len(entities)
        })
        
    except Exception as e:
        print(f"GLiNER processing error: {e}")
        return jsonify({'error': f'GLiNER processing failed: {str(e)}'}), 500
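
# Example request/response for /run_gliner (illustrative; the spans, labels
# and confidence values depend on the gliner_small-v2.5 model's predictions):
#   POST /run_gliner
#   {"text": "Marie Curie worked in Paris.", "labels": ["PERSON", "LOCATION"]}
#   -> {"entities": [{"text": "Marie Curie", "start": 0, "end": 11,
#                     "label": "PERSON", "confidence": 0.97},
#                    {"text": "Paris", "start": 22, "end": 27,
#                     "label": "LOCATION", "confidence": 0.95}],
#       "total_found": 2}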

@app.route('/run_gliner_relationships', methods=['POST'])
def run_gliner_relationships():
    """
    Run GLiNER relationship extraction on the provided text with specified relationship labels
    """
    data = request.get_json()
    text = data.get('text', '')
    relationship_labels = data.get('relationship_labels', [])
    entity_labels = data.get('entity_labels', ["person", "organization", "location", "date", "place"])
    
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    
    if not relationship_labels:
        return jsonify({'error': 'No relationship labels provided'}), 400
    
    try:
        # Get GLiNER multitask model
        model = get_or_create_multitask_model()
        
        if model is None:
            return jsonify({'error': 'Failed to initialize GLiNER multitask model'}), 500
        
        # First extract entities using the provided entity labels
        print(f"Using entity labels: {entity_labels}")
        entities = model.predict_entities(text, entity_labels, threshold=0.3)
        
        # Then extract relationships: the multitask model is prompted with
        # labels of the form "<entity> <> <relation>", pairing each entity
        # label with each relationship label
        formatted_labels = []
        for label in relationship_labels:
            for entity_label in entity_labels:
                formatted_labels.append(f"{entity_label} <> {label}")
        
        print(f"Formatted relationship labels: {formatted_labels}")
        
        relation_entities = model.predict_entities(text, formatted_labels, threshold=0.3)
        
        # Process results into relationship triplets
        relationships = []
        
        # Group relation entities by their relation type and try to find entity pairs
        for rel_entity in relation_entities:
            label_parts = rel_entity['label'].split(' <> ')
            if len(label_parts) == 2:
                entity_type, relation_type = label_parts
                
                # Find potential subject and object entities near this relation
                rel_start = rel_entity['start']
                rel_end = rel_entity['end']
                
                # Look for entities before and after the relation mention
                subject_candidates = [e for e in entities if e['end'] <= rel_start and abs(e['end'] - rel_start) < 100]
                object_candidates = [e for e in entities if e['start'] >= rel_end and abs(e['start'] - rel_end) < 100]
                
                # Also look for entities that contain or are contained by the relation text
                overlapping_entities = [e for e in entities if 
                    (e['start'] <= rel_start and e['end'] >= rel_end) or  # entity contains relation
                    (rel_start <= e['start'] and rel_end >= e['end'])     # relation contains entity
                ]
                
                if subject_candidates and object_candidates:
                    # Take the closest entities
                    subject = max(subject_candidates, key=lambda x: x['end'])
                    object_entity = min(object_candidates, key=lambda x: x['start'])
                    
                    relationships.append({
                        'subject': subject['text'],
                        'subject_start': subject['start'],
                        'subject_end': subject['end'],
                        'relation_type': relation_type,
                        'relation_text': rel_entity['text'],
                        'relation_start': rel_entity['start'],
                        'relation_end': rel_entity['end'],
                        'object': object_entity['text'],
                        'object_start': object_entity['start'],
                        'object_end': object_entity['end'],
                        'confidence': rel_entity['score'],
                        'full_text': f"{subject['text']} {relation_type} {object_entity['text']}"
                    })
                elif overlapping_entities:
                    # Handle cases where the relation text spans or overlaps with entities
                    for ent in overlapping_entities:
                        relationships.append({
                            'subject': ent['text'],
                            'subject_start': ent['start'],
                            'subject_end': ent['end'],
                            'relation_type': relation_type,
                            'relation_text': rel_entity['text'],
                            'relation_start': rel_entity['start'],
                            'relation_end': rel_entity['end'],
                            'object': '',  # Will be filled by user or further processing
                            'object_start': -1,
                            'object_end': -1,
                            'confidence': rel_entity['score'],
                            'full_text': f"{ent['text']} {relation_type} [object]"
                        })
        
        return jsonify({
            'relationships': relationships,
            'total_found': len(relationships)
        })
        
    except Exception as e:
        print(f"GLiNER relationship processing error: {e}")
        return jsonify({'error': f'GLiNER relationship processing failed: {str(e)}'}), 500
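
# Example request/response for /run_gliner_relationships (illustrative; the
# triplets depend on the multitask model and the proximity heuristic above):
#   POST /run_gliner_relationships
#   {"text": "Marie Curie worked at the Sorbonne.",
#    "relationship_labels": ["worked at"]}
#   -> {"relationships": [{"subject": "Marie Curie", "relation_type": "worked at",
#                          "object": "Sorbonne", "confidence": 0.82, ...}],
#       "total_found": 1}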

@app.route('/search_wikidata', methods=['POST'])
def search_wikidata():
    """
    Search Wikidata for entities matching the query
    """
    data = request.get_json()
    query = data.get('query', '').strip()
    limit = data.get('limit', 10)
    
    if not query:
        return jsonify({'error': 'No query provided'}), 400
    
    try:
        # Wikidata search API endpoint
        url = 'https://www.wikidata.org/w/api.php'
        
        params = {
            'action': 'wbsearchentities',
            'search': query,
            'language': 'en',
            'format': 'json',
            'limit': limit,
            'type': 'item'
        }
        
        headers = {
            'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
        }
        
        response = requests.get(url, params=params, headers=headers, timeout=10)
        response.raise_for_status()
        
        # Parse the API response (named to avoid shadowing the request payload)
        api_data = response.json()
        
        # Extract relevant information
        results = []
        if 'search' in api_data:
            for item in api_data['search']:
                result = {
                    'id': item.get('id', ''),
                    'label': item.get('label', ''),
                    'description': item.get('description', ''),
                    'url': f"https://www.wikidata.org/wiki/{item.get('id', '')}"
                }
                results.append(result)
        
        return jsonify({
            'results': results,
            'total': len(results)
        })
        
    except requests.exceptions.RequestException as e:
        print(f"Wikidata API error: {e}")
        return jsonify({'error': 'Failed to search Wikidata'}), 500
    except Exception as e:
        print(f"Wikidata search error: {e}")
        return jsonify({'error': f'Search failed: {str(e)}'}), 500
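
# Example request/response for /search_wikidata (illustrative; results come
# from the live wbsearchentities API):
#   POST /search_wikidata  {"query": "Ada Lovelace", "limit": 5}
#   -> {"results": [{"id": "Q7259", "label": "Ada Lovelace",
#                    "description": "...", "url": "https://www.wikidata.org/wiki/Q7259"},
#                   ...],
#       "total": 5}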

@app.route('/get_wikidata_entity', methods=['POST'])
def get_wikidata_entity():
    """
    Get Wikidata entity information by Q-code
    """
    data = request.get_json()
    qcode = data.get('qcode', '').strip()
    
    if not qcode:
        return jsonify({'error': 'No Q-code provided'}), 400
    
    # Normalize to Wikidata's Q-code format (accepts "Q42", "q42", or "42")
    qcode = qcode.upper()
    if not qcode.startswith('Q'):
        qcode = 'Q' + qcode
    
    try:
        # Wikidata entity API endpoint
        url = 'https://www.wikidata.org/w/api.php'
        
        params = {
            'action': 'wbgetentities',
            'ids': qcode,
            'languages': 'en',
            'format': 'json'
        }
        
        headers = {
            'User-Agent': 'AnnotationTool/1.0 (https://github.com/user/annotation-tool) Python/requests'
        }
        
        response = requests.get(url, params=params, headers=headers, timeout=10)
        response.raise_for_status()
        
        api_data = response.json()
        
        if 'entities' in api_data and qcode in api_data['entities']:
            entity = api_data['entities'][qcode]
            
            if 'missing' in entity:
                return jsonify({'error': f'Entity {qcode} not found'}), 404
            
            # Extract information
            result = {
                'id': qcode,
                'label': entity.get('labels', {}).get('en', {}).get('value', ''),
                'description': entity.get('descriptions', {}).get('en', {}).get('value', ''),
                'url': f"https://www.wikidata.org/wiki/{qcode}"
            }
            
            return jsonify({'entity': result})
        else:
            return jsonify({'error': f'Entity {qcode} not found'}), 404
            
    except requests.exceptions.RequestException as e:
        print(f"Wikidata API error: {e}")
        return jsonify({'error': 'Failed to get Wikidata entity'}), 500
    except Exception as e:
        print(f"Wikidata entity error: {e}")
        return jsonify({'error': f'Request failed: {str(e)}'}), 500
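
# Example request/response for /get_wikidata_entity (illustrative; bare
# numbers and lowercase "q" prefixes are normalized above):
#   POST /get_wikidata_entity  {"qcode": "q7259"}
#   -> {"entity": {"id": "Q7259", "label": "Ada Lovelace", "description": "...",
#                  "url": "https://www.wikidata.org/wiki/Q7259"}}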

if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)