Spaces:
Sleeping
Sleeping
| """NER annotation module using GLiNER models.""" | |
| from typing import List, Dict, Union, Optional | |
| import torch | |
| import random | |
| from gliner import GLiNER | |
| from ..utils.text_processing import tokenize_text | |
class AutoAnnotator:
    """Automatic NER annotation using a GLiNER model.

    Texts are annotated in batches; the result of the latest run is kept in
    ``annotated_data`` and coarse progress information in ``stat``.
    """

    def __init__(
        self,
        model: str = "BookingCare/gliner-multi-healthcare",
        device: Optional[torch.device] = None
    ) -> None:
        """Initialize the annotator with a GLiNER model.

        Args:
            model: Name or path of the GLiNER model to use.
            device: Device to run the model on. When ``None``, ``cuda:0`` is
                chosen if CUDA is available, otherwise the CPU.
        """
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        else:
            # Normalize so that ``device.type`` can be inspected below.
            device = torch.device(device)

        # Only touch CUDA memory settings when the model will actually run on
        # a GPU; a caller that asked for the CPU should not get a CUDA
        # context initialized as a side effect.
        if device.type == 'cuda' and torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Cap this process at 80% of the GPU memory.
            torch.cuda.set_per_process_memory_fraction(0.8)

        self.model = GLiNER.from_pretrained(model).to(device)
        # Result of the most recent ``auto_annotate`` call.
        self.annotated_data = []
        # Progress of the current run: number of texts in total / handled so
        # far (``-1`` until the first batch has finished).
        self.stat = {
            "total": None,
            "current": -1
        }

    def auto_annotate(
        self,
        data: List[str],
        labels: List[str],
        prompt: Optional[Union[str, List[str]]] = None,
        threshold: float = 0.5,
        nested_ner: bool = False,
        batch_size: int = 8
    ) -> List[Dict]:
        """Annotate a list of texts with NER labels.

        Args:
            data: List of texts to annotate.
            labels: List of entity labels to detect.
            prompt: Optional prompt, or list of prompts from which one is
                picked at random for every text. The prompt is prepended to
                the text (separated by a newline) before prediction, so it is
                part of the tokenized output as well. ``None``, an empty
                string or an empty list mean "no prompt".
            threshold: Confidence threshold for entity detection.
            nested_ner: Whether to allow nested entities.
            batch_size: Number of texts sent to the model at once. The small
                default keeps memory usage low to avoid OOM errors.

        Returns:
            List of annotated examples, one per input text, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        total = len(data)
        self.stat["total"] = total
        self.stat["current"] = -1

        processed_data = []
        for offset in range(0, total, batch_size):
            batch_with_prompts = [
                self._apply_prompt(text, prompt)
                for text in data[offset:offset + batch_size]
            ]
            processed_data.extend(
                self._batch_annotate_text(
                    batch_with_prompts,
                    labels,
                    threshold,
                    nested_ner
                )
            )

            # Clear CUDA cache after each batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Update progress
            self.stat["current"] = min(offset + batch_size, total)

        self.annotated_data = processed_data
        return self.annotated_data

    @staticmethod
    def _apply_prompt(text: str, prompt: Optional[Union[str, List[str]]]) -> str:
        """Return ``text`` with a prompt line prepended, if a prompt is given.

        Args:
            text: The text to annotate.
            prompt: A single prompt, a list of prompts (one is chosen at
                random), or ``None``.

        Returns:
            ``"<prompt>\\n<text>"`` when a non-empty prompt is available,
            otherwise ``text`` unchanged.
        """
        if isinstance(prompt, list):
            # random.choice raises IndexError on an empty sequence, so an
            # empty prompt list is treated like "no prompt".
            prompt_text = random.choice(prompt) if prompt else None
        else:
            prompt_text = prompt
        return f"{prompt_text}\n{text}" if prompt_text else text

    def _batch_annotate_text(
        self,
        texts: List[str],
        labels: List[str],
        threshold: float,
        nested_ner: bool
    ) -> List[Dict]:
        """Annotate multiple texts in one model call.

        Args:
            texts: List of texts to annotate (already including any prompt).
            labels: List of entity labels.
            threshold: Confidence threshold.
            nested_ner: Whether to allow nested entities; forwarded to the
                model inverted, as ``flat_ner``.

        Returns:
            List of annotated examples in the format produced by
            ``_transform_data``.
        """
        batch_entities = self.model.batch_predict_entities(
            texts,
            labels,
            flat_ner=not nested_ner,
            threshold=threshold
        )

        results = []
        for text, entities in zip(texts, batch_entities):
            converted = [
                {
                    "entity": entity["label"],
                    "word": entity["text"],
                    "start": entity["start"],
                    "end": entity["end"],
                    # The score is recorded as 0 here; it is not part of the
                    # transformed output.
                    "score": 0,
                }
                for entity in entities
            ]
            example = {
                "text": text,
                "entities": self._merge_entities(converted, text),
            }
            results.append(self._transform_data(example))
        return results

    def _merge_entities(
        self,
        entities: List[Dict],
        text: Optional[str] = None
    ) -> List[Dict]:
        """Merge adjacent entities of the same type.

        Two consecutive entities are merged when they carry the same label
        and the second one starts directly at, or one character after, the
        end of the first one. The input list and its dictionaries are left
        untouched; merged entities are fresh dictionaries.

        Args:
            entities: List of entity dictionaries with ``entity``, ``word``,
                ``start`` and ``end`` keys, in order of appearance.
            text: Optional source text the offsets refer to. When given, the
                characters between two merged entities are taken from it;
                otherwise a space is used for a one-character gap and
                nothing for touching spans.

        Returns:
            List of merged entities.
        """
        if not entities:
            return []

        merged = []
        # Work on copies so that the caller's dictionaries are not mutated.
        current = dict(entities[0])
        for next_entity in entities[1:]:
            gap = next_entity['start'] - current['end']
            if next_entity['entity'] == current['entity'] and gap in (0, 1):
                if text is not None:
                    separator = text[current['end']:next_entity['start']]
                else:
                    separator = ' ' * gap
                current['word'] += separator + next_entity['word']
                current['end'] = next_entity['end']
            else:
                merged.append(current)
                current = dict(next_entity)
        merged.append(current)
        return merged

    @staticmethod
    def _find_token_span(
        tokens: List[str],
        entity_tokens: List[str],
        label: str,
        taken: set
    ) -> Optional[List]:
        """Locate the first unused occurrence of ``entity_tokens`` in ``tokens``.

        Args:
            tokens: Tokens of the full text.
            entity_tokens: Tokens of the entity to look for.
            label: Entity label stored in the resulting span.
            taken: Set of ``(start, end, label)`` triples that were already
                returned; it is updated in place when a span is found.

        Returns:
            ``[start, end, label]`` with inclusive token indices, or ``None``
            when the entity has no tokens or no unused occurrence exists.
        """
        length = len(entity_tokens)
        if length == 0:
            # An empty token list would match at index 0 and yield the
            # invalid span [0, -1, label].
            return None

        for start in range(len(tokens) - length + 1):
            if tokens[start:start + length] != entity_tokens:
                continue
            key = (start, start + length - 1, label)
            if key in taken:
                # Already assigned to an earlier mention; look further on.
                continue
            taken.add(key)
            return [key[0], key[1], label]
        return None

    def _transform_data(self, data: Dict) -> Dict:
        """Transform raw annotation data into tokenized format.

        Entities are located by matching their tokens against the tokens of
        the text. Repeated mentions of the same word with the same label are
        mapped to successive occurrences instead of all pointing at the
        first one.

        Args:
            data: Raw annotation data with a ``text`` string and an
                ``entities`` list of dictionaries (``word`` and ``entity``
                keys are read).

        Returns:
            Dictionary with the tokenized text (``tokenized_text``), the
            ``[start, end, label]`` token spans (``ner``) and ``validated``
            set to ``False``.
        """
        tokens = tokenize_text(data['text'])
        spans = []
        taken = set()

        for entity in data['entities']:
            span = self._find_token_span(
                tokens,
                tokenize_text(entity['word']),
                entity['entity'],
                taken
            )
            if span is not None:
                spans.append(span)

        return {
            "tokenized_text": tokens,
            "ner": spans,
            "validated": False
        }