# Prodigy ner.manual-style teaching recipe with gold-answer validation.
| import copy | |
| from functools import partial | |
| from typing import Callable, Iterable, List, Optional, Tuple, Union, Dict, Any | |
| import murmurhash | |
| from spacy.language import Language | |
| from spacy.tokens.doc import SetEntsDefault # type: ignore | |
| from spacy.training import Example | |
| from spacy.util import filter_spans | |
| from prodigy.components.db import connect | |
| from prodigy.components.decorators import support_both_streams | |
| from prodigy.components.filters import filter_seen_before | |
| from prodigy.components.preprocess import ( | |
| add_annot_name, | |
| add_tokens, | |
| add_view_id, | |
| make_ner_suggestions, | |
| make_raw_doc, | |
| resolve_labels, | |
| split_sentences, | |
| ) | |
| from prodigy.components.sorters import prefer_uncertain | |
| from prodigy.components.source import GeneratorSource | |
| from prodigy.components.stream import Stream, get_stream, load_noop | |
| from prodigy.core import Arg, recipe | |
| from prodigy.errors import RecipeError | |
| from prodigy.models.matcher import PatternMatcher | |
| from prodigy.models.ner import EntityRecognizerModel, ensure_sentencizer | |
| from prodigy.protocols import ControllerComponentsDict | |
| from prodigy.types import ( | |
| ExistingFilePath, | |
| LabelsType, | |
| SourceType, | |
| StreamType, | |
| TaskType, | |
| ) | |
| from prodigy.util import ( | |
| ANNOTATOR_ID_ATTR, | |
| BINARY_ATTR, | |
| INPUT_HASH_ATTR, | |
| TASK_HASH_ATTR, | |
| combine_models, | |
| copy_nlp, | |
| get_pipe_labels, | |
| log, | |
| msg, | |
| set_hashes, | |
| ) | |
def modify_spans(document):
    """Clear all span annotations from the given task dict.

    Mutates *document* in place (its "spans" key becomes an empty list)
    and returns it so the call can be used inside a generator pipeline.
    """
    document["spans"] = []
    return document
def spans_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool:
    """Return True when both spans cover the same character offsets."""
    return (s1["start"], s1["end"]) == (s2["start"], s2["end"])
def labels_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool:
    """Return True when the two spans carry the same label."""
    first = s1["label"]
    second = s2["label"]
    return first == second
def ensure_span_text(eg: TaskType) -> TaskType:
    """Backfill the "text" attribute on every span of the task.

    Spans that already carry "text" are left untouched; for the rest the
    surface string is sliced out of the example's text using the span's
    character offsets. The task is mutated in place and returned.
    """
    for span in eg.get("spans", []):
        if "text" in span:
            continue
        span["text"] = eg["text"][span["start"] : span["end"]]
    return eg
def validate_answer(answer: TaskType, *, known_answers_map: Dict[int, TaskType]):
    """Validate an annotator's answer against the known gold answer.

    Looks up the gold example by input hash and compares span counts,
    labels, and boundaries.

    Raises:
        ValueError: with a human-readable explanation (shown to the
            annotator by Prodigy) when the answer deviates from the gold.
    """
    known_answer = known_answers_map.get(answer[INPUT_HASH_ATTR])
    if known_answer is None:
        # No gold example for this input: accept the answer without checks.
        print(f"Skipping validation for answer {answer[INPUT_HASH_ATTR]}, no known answer found to validate against.")
        return
    known_answer = ensure_span_text(known_answer)
    errors = []
    known_spans = known_answer.get("spans", [])
    answer_spans = answer.get("spans", [])
    explanation_label = known_answer.get("meta", {}).get("explanation_label")
    explanation_boundaries = known_answer.get("meta", {}).get(
        "explanation_boundaries"
    )
    # Fall back to placeholder text for BOTH explanations: previously only
    # the boundaries explanation had a default, so a missing
    # meta.explanation_label appended None to `errors` and made
    # "\n".join(errors) raise TypeError instead of the intended ValueError.
    if not explanation_label:
        explanation_label = "No explanation label"
    if not explanation_boundaries:
        explanation_boundaries = (
            "No explanation boundaries"
        )
    if len(known_spans) > len(answer_spans):
        errors.append(
            "You noted fewer entities than expected for this answer. All mentions must be annotated"
        )
    elif len(known_spans) < len(answer_spans):
        errors.append(
            "You noted more entities than expected for this answer."
        )
    if not known_spans and answer_spans:
        # No annotations are expected for this example, but some were
        # given: add the label explanation. (Guarding on `answer_spans`
        # fixes a false positive where a correct, empty answer raised.)
        errors.append(explanation_label)
    # zip pairs spans positionally; when the counts differ, the length
    # checks above have already flagged the mismatch.
    for known_span, span in zip(known_spans, answer_spans):
        if not labels_equal(known_span, span):
            # label error
            errors.append(explanation_label)
            continue
        if not spans_equal(known_span, span):
            # boundary error
            errors.append(explanation_boundaries)
            continue
    if len(errors) > 0:
        error_msg = "\n".join(errors)
        error_msg += "\n\nExpected annotations:"
        if known_spans:
            expected_spans = [
                f'[{s["text"]}]: {s["label"]}' for s in known_spans
            ]
            if expected_spans:
                error_msg += "\n"
                for span_msg in expected_spans:
                    error_msg += span_msg + "\n"
        else:
            error_msg += "\n\nNone."
        raise ValueError(error_msg)
def manual(
    dataset: str,
    nlp: Language,
    source: SourceType,
    loader: Optional[str] = None,
    label: Optional[LabelsType] = None,
    patterns: Optional[ExistingFilePath] = None,
    exclude: List[str] = [],
    highlight_chars: bool = False,
) -> ControllerComponentsDict:
    """
    Mark spans by token. Requires only a tokenizer and no entity recognizer,
    and doesn't do any active learning. If patterns are provided, their matches
    are highlighted in the example, if available. The recipe will present
    all examples in order, so even examples without matches are shown. If
    character highlighting is enabled, no "tokens" are saved to the database.
    """
    log("RECIPE: Starting recipe ner.manual", locals())
    labels = get_pipe_labels(label, nlp.pipe_labels.get("ner", []))
    stream = get_stream(
        source,
        loader=loader,
        rehash=True,
        dedup=True,
        input_key="text",
        is_binary=False,
    )
    if patterns is not None:
        pattern_matcher = PatternMatcher(nlp, combine_matches=True, all_examples=True)
        pattern_matcher = pattern_matcher.from_disk(patterns)
        stream.apply(lambda examples: (eg for _, eg in pattern_matcher(examples)))
    # Clear any pre-existing span annotations so annotators start from a
    # clean slate. NOTE(review): this also wipes the pattern matches added
    # above — presumably intentional so trainees annotate from scratch
    # while the gold map below keeps the expected answers; confirm.
    stream.apply(lambda examples: (modify_spans(eg) for eg in examples))
    # NOTE(review): `exclude` is typed List[str] but items are read via
    # `.name`, implying dataset objects — confirm against the CLI wiring.
    exclude_names = [ds.name for ds in exclude] if exclude is not None else None
    # Load the source a second time to build the gold-answer map used by
    # validate_answer, keyed by input hash for lookup per submitted answer.
    known_answers = get_stream(
        source,
        loader=loader,
        rehash=True,
        dedup=True,
        input_key="text",
        is_binary=False,
    )
    known_answers_map = {eg[INPUT_HASH_ATTR]: eg for eg in known_answers}
    return {
        "view_id": "ner_manual",
        "dataset": dataset,
        # Materialize the lazy stream so all examples are available up front.
        "stream": list(stream),
        "exclude": exclude_names,
        "validate_answer": partial(validate_answer, known_answers_map=known_answers_map),
        "config": {
            "lang": nlp.lang,
            "labels": labels,
            "exclude_by": "input",
            "ner_manual_highlight_chars": highlight_chars,
        },
    }
def preprocess_stream(
    stream: StreamType,
    nlp: Language,
    *,
    labels: Optional[List[str]],
    unsegmented: bool,
    set_annotations: bool = True,
) -> StreamType:
    """Tokenize (and optionally sentence-segment) the stream, then pre-annotate.

    When set_annotations is True, each task gets a "spans" key combining the
    model's predicted entities (restricted to *labels* when given) with any
    spans already present on the task; overlaps are resolved via filter_spans
    and hashes are recomputed. Otherwise the tokenized stream is passed
    through unchanged.
    """
    if not unsegmented:
        stream = split_sentences(nlp, stream)  # type: ignore
    stream = add_tokens(nlp, stream)  # type: ignore
    if set_annotations:
        spacy_model = f"{nlp.meta['lang']}_{nlp.meta['name']}"
        # Add a 'spans' key to each example, with predicted entities
        texts = ((eg["text"], eg) for eg in stream)
        for doc, eg in nlp.pipe(texts, as_tuples=True, batch_size=10):
            task = copy.deepcopy(eg)
            spans = []
            for ent in doc.ents:
                if labels and ent.label_ not in labels:
                    continue
                spans.append(ent)
            for span in eg.get("spans", []):
                char_span = doc.char_span(span["start"], span["end"], span["label"])
                # char_span returns None when the offsets don't align with
                # token boundaries; skip those instead of letting None
                # crash filter_spans below.
                if char_span is not None:
                    spans.append(char_span)
            spans = filter_spans(spans)
            span_dicts = []
            for ent in spans:
                span_dicts.append(
                    {
                        "token_start": ent.start,
                        "token_end": ent.end - 1,
                        "start": ent.start_char,
                        "end": ent.end_char,
                        "text": ent.text,
                        "label": ent.label_,
                        "source": spacy_model,
                        "input_hash": eg[INPUT_HASH_ATTR],
                    }
                )
            task["spans"] = span_dicts
            task[BINARY_ATTR] = False
            task = set_hashes(task)
            yield task
    else:
        yield from stream
def get_ner_labels(
    nlp: Language, *, label: Optional[List[str]], component: str = "ner"
) -> Tuple[List[str], bool]:
    """Resolve the label set for *component* and report whether it is complete.

    Returns the labels to annotate plus a flag that is True when the chosen
    set covers every label the model's component predicts (i.e. we're
    annotating all model labels, not a subset).
    """
    model_labels = nlp.pipe_labels.get(component, [])
    labels = get_pipe_labels(label, model_labels)
    overlap = set(labels).intersection(set(model_labels))
    no_missing = len(overlap) == len(model_labels)
    return labels, no_missing
def get_update(nlp: Language, *, no_missing: bool) -> Callable[[List[TaskType]], None]:
    """Build an ``update`` callback that fine-tunes *nlp* on accepted answers.

    no_missing: True when the annotated label set covers all of the model's
    labels; unannotated tokens are then treated as outside any entity
    (SetEntsDefault.outside) rather than as missing values.
    """
    def update(answers: List[TaskType]) -> None:
        log(f"RECIPE: Updating model with {len(answers)} answers")
        examples = []
        for eg in answers:
            if eg["answer"] == "accept":
                doc = make_raw_doc(nlp, eg)
                ref = make_raw_doc(nlp, eg)
                spans = [
                    doc.char_span(span["start"], span["end"], label=span["label"])
                    for span in eg.get("spans", [])
                ]
                # char_span yields None for offsets that don't align with
                # token boundaries; drop those so ref.set_ents doesn't raise.
                spans = [span for span in spans if span is not None]
                value = SetEntsDefault.outside if no_missing else SetEntsDefault.missing
                ref.set_ents(spans, default=value)
                examples.append(Example(doc, ref))
        nlp.update(examples)

    return update