Spaces:
Sleeping
Sleeping
import logging
import os
from typing import Any, Dict, List, Optional

import spacy
from huggingface_hub import snapshot_download
| HF_MODEL_ID = "kn29/my-ner-model" | |
| logger = logging.getLogger(__name__) | |
| # Global variable to store the loaded model | |
| _nlp_model = None | |
| def _initialize_model(model_id: str = None): | |
| """Initialize the NER model""" | |
| global _nlp_model | |
| if _nlp_model is not None: | |
| return _nlp_model | |
| if model_id is None: | |
| model_id = HF_MODEL_ID | |
| try: | |
| logger.info(f"Loading NER model from Hugging Face: {model_id}") | |
| token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN") | |
| local_dir = snapshot_download( | |
| repo_id=model_id, | |
| token=token if token else None | |
| ) | |
| _nlp_model = spacy.load(local_dir) | |
| logger.info( | |
| f"Successfully loaded NER model from {model_id} (token={'yes' if token else 'no'})" | |
| ) | |
| except Exception as e: | |
| logger.error(f"Failed to load NER model from {model_id}: {str(e)}") | |
| # Fallback to standard English model | |
| try: | |
| logger.info("Falling back to standard English model") | |
| _nlp_model = spacy.load("en_core_web_sm") | |
| except Exception as fallback_error: | |
| logger.error(f"Fallback model also failed: {str(fallback_error)}") | |
| raise Exception(f"No spaCy model available: {str(e)}") | |
| return _nlp_model | |
| def process_text(text: str, model_id: str = None) -> Dict[str, Any]: | |
| """Process text with NER model""" | |
| try: | |
| nlp = _initialize_model(model_id) | |
| if len(text) > 4000000: | |
| logger.info(f"Text too large ({len(text)} chars), processing in chunks") | |
| return _process_large_text(text, nlp) | |
| doc = nlp(text) | |
| entities = [] | |
| entity_counts = {} | |
| for ent in doc.ents: | |
| processed_entities = _process_entity(ent) | |
| for entity_text, entity_label in processed_entities: | |
| entity_info = { | |
| "text": entity_text, | |
| "label": entity_label, | |
| "start": ent.start_char, | |
| "end": ent.end_char | |
| } | |
| entities.append(entity_info) | |
| if entity_label not in entity_counts: | |
| entity_counts[entity_label] = [] | |
| entity_counts[entity_label].append(entity_text) | |
| for label in entity_counts: | |
| unique_entities = list(set(entity_counts[label])) | |
| entity_counts[label] = { | |
| "entities": unique_entities, | |
| "count": len(unique_entities) | |
| } | |
| return { | |
| "entities": entities, | |
| "entity_counts": entity_counts, | |
| "total_entities": len(entities), | |
| "unique_labels": list(entity_counts.keys()) | |
| } | |
| except Exception as e: | |
| logger.error(f"Error processing text with NER: {str(e)}") | |
| return { | |
| "error": str(e), | |
| "entities": [], | |
| "entity_counts": {}, | |
| "total_entities": 0 | |
| } | |
| def _process_large_text(text: str, nlp, chunk_size: int = 3000000) -> Dict[str, Any]: | |
| """Process large text in chunks""" | |
| chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)] | |
| all_entities = [] | |
| all_entity_counts = {} | |
| for i, chunk in enumerate(chunks): | |
| logger.info(f"Processing chunk {i+1}/{len(chunks)}") | |
| try: | |
| doc = nlp(chunk) | |
| for ent in doc.ents: | |
| processed_entities = _process_entity(ent) | |
| for entity_text, entity_label in processed_entities: | |
| entity_info = { | |
| "text": entity_text, | |
| "label": entity_label, | |
| "start": ent.start_char + (i * chunk_size), | |
| "end": ent.end_char + (i * chunk_size) | |
| } | |
| all_entities.append(entity_info) | |
| if entity_label not in all_entity_counts: | |
| all_entity_counts[entity_label] = [] | |
| all_entity_counts[entity_label].append(entity_text) | |
| except Exception as e: | |
| logger.error(f"Error processing chunk {i+1}: {str(e)}") | |
| continue | |
| for label in all_entity_counts: | |
| unique_entities = list(set(all_entity_counts[label])) | |
| all_entity_counts[label] = { | |
| "entities": unique_entities, | |
| "count": len(unique_entities) | |
| } | |
| return { | |
| "entities": all_entities, | |
| "entity_counts": all_entity_counts, | |
| "total_entities": len(all_entities), | |
| "unique_labels": list(all_entity_counts.keys()), | |
| "processed_in_chunks": True, | |
| "num_chunks": len(chunks) | |
| } | |
| def _process_entity(ent) -> List[tuple]: | |
| """Process individual entity, handling special cases""" | |
| if ent.label_ in ["PRECEDENT", "ORG"] and " and " in ent.text: | |
| parts = ent.text.split(" and ") | |
| return [(p.strip(), "ORG") for p in parts] | |
| return [(ent.text, ent.label_)] |