Yassine committed on
Commit
a966ebf
·
1 Parent(s): e0a7220

fix extraction

Browse files
Files changed (1) hide show
  1. main.py +17 -12
main.py CHANGED
@@ -247,7 +247,8 @@ async def extract_entities(input_data: TextInput):
247
  predictions = outputs.logits.argmax(dim=2)
248
  entities = {}
249
  current_entity = None
250
- current_text = []
 
251
 
252
  word_ids = inputs.word_ids(0)
253
  for idx, word_idx in enumerate(word_ids):
@@ -260,31 +261,35 @@ async def extract_entities(input_data: TextInput):
260
  predicted_label = id2label[prediction]
261
 
262
  if predicted_label.startswith("B-"):
263
- if current_entity:
264
  entity_type = current_entity[2:]
265
  if entity_type not in entities:
266
- entities[entity_type] = [" ".join(current_text)]
267
- # Only keep the first detection, do nothing if already present
 
 
268
 
269
  current_entity = predicted_label
270
- current_text = [clean_tokens[word_idx]]
271
 
272
  elif predicted_label.startswith("I-") and current_entity and predicted_label[2:] == current_entity[2:]:
273
- current_text.append(clean_tokens[word_idx])
 
 
274
 
275
  else:
276
- if current_entity:
277
  entity_type = current_entity[2:]
278
  if entity_type not in entities:
279
- entities[entity_type] = [" ".join(current_text)]
280
- # Only keep the first detection, do nothing if already present
281
  current_entity = None
282
- current_text = []
 
283
 
284
- if current_entity:
285
  entity_type = current_entity[2:]
286
  if entity_type not in entities:
287
- entities[entity_type] = [" ".join(current_text)]
288
  # Only keep the first detection, do nothing if already present
289
 
290
  return {"entities": entities}
 
247
  predictions = outputs.logits.argmax(dim=2)
248
  entities = {}
249
  current_entity = None
250
+ current_start = None
251
+ current_end = None
252
 
253
  word_ids = inputs.word_ids(0)
254
  for idx, word_idx in enumerate(word_ids):
 
261
  predicted_label = id2label[prediction]
262
 
263
  if predicted_label.startswith("B-"):
264
+ if current_entity is not None:
265
  entity_type = current_entity[2:]
266
  if entity_type not in entities:
267
+ entities[entity_type] = [text[current_start:current_end]]
268
+ current_entity = None
269
+ current_start = None
270
+ current_end = None
271
 
272
  current_entity = predicted_label
273
+ current_start, current_end = token_positions[word_idx]
274
 
275
  elif predicted_label.startswith("I-") and current_entity and predicted_label[2:] == current_entity[2:]:
276
+ # Extend the end position to include this token
277
+ _, token_end = token_positions[word_idx]
278
+ current_end = token_end
279
 
280
  else:
281
+ if current_entity is not None:
282
  entity_type = current_entity[2:]
283
  if entity_type not in entities:
284
+ entities[entity_type] = [text[current_start:current_end]]
 
285
  current_entity = None
286
+ current_start = None
287
+ current_end = None
288
 
289
+ if current_entity is not None:
290
  entity_type = current_entity[2:]
291
  if entity_type not in entities:
292
+ entities[entity_type] = [text[current_start:current_end]]
293
  # Only keep the first detection, do nothing if already present
294
 
295
  return {"entities": entities}