---
library_name: transformers
license: apache-2.0
datasets:
- Universal-NER/Pile-NER-type
- Universal-NER/Pile-NER-definition
language:
- en
base_model:
- google/flan-t5-small
pipeline_tag: text2text-generation
tags:
- named-entity-recognition
- generated_from_trainer
---
# flan-t5-small-ner

This model is a fine-tuned version of [google/flan-t5-small](https://huggingface.co/google/flan-t5-small)
on 200,000 random (text, entity) pairs from the
[Universal-NER/Pile-NER-type](https://huggingface.co/datasets/Universal-NER/Pile-NER-type) and
[Universal-NER/Pile-NER-definition](https://huggingface.co/datasets/Universal-NER/Pile-NER-definition) datasets.

It achieves the following results on the evaluation set:

- Loss: 0.5393
- Num Input Tokens Seen: 332318598

## Model Description

flan-t5-small-ner extracts entities of a user-specified type or definition from text, for example person, company, school, or technology.
It builds on the FLAN-T5 architecture, which performs strongly across a wide range of natural language processing tasks.

Example:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

model_path = "agentlans/flan-t5-small-ner"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path)

def custom_split(s):  # Processes the output from the model
    parts = s.split("<|sep|>")
    if not s.endswith("<|end|>"):
        parts = parts[:-1]  # If output is truncated, then don't include last item
    else:
        parts[-1] = parts[-1].replace("<|end|>", "")  # Remove the marker tokens
    return [p.strip() for p in parts if p.strip()]

def find_entities(input_text, entity_type):
    txt = entity_type + "<|sep|>" + input_text + "<|end|>"  # Important: need exact input format
    inputs = tokenizer(txt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return custom_split(decoded)

# Example usage
input_text = "In the bustling metropolis of New York City, Apple Inc. sponsored a conference where Dr. Elena Rodriguez presented groundbreaking research about neuroscience and AI."
print(find_entities(input_text, "person"))   # ['Elena Rodriguez']
print(find_entities(input_text, "company"))  # ['Apple Inc.']
print(find_entities(input_text, "fruit"))    # []
print(find_entities(input_text, "subject"))  # ['neuroscience', 'AI']
```
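
The same prompt format can also be used through the `text2text-generation` pipeline. The sketch below is an illustrative alternative, not part of the original example: `find_entities_pipeline` is a hypothetical helper, and it reuses the `custom_split` parser and `input_text` defined above.

```python
from transformers import pipeline

# Sketch: same model via the text2text-generation pipeline.
# Assumes custom_split() and input_text from the example above are defined.
ner = pipeline("text2text-generation", model="agentlans/flan-t5-small-ner")

def find_entities_pipeline(text, entity_type, max_new_tokens=100):
    prompt = f"{entity_type}<|sep|>{text}<|end|>"  # same exact input format as above
    generated = ner(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
    return custom_split(generated)

print(find_entities_pipeline(input_text, "person"))  # should match find_entities(input_text, "person")
```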

## Limitations

- False positives and false negatives are possible.
- May struggle with specialized knowledge or fine-grained distinctions.
- Performance may vary for very short or very long texts (a simple chunking workaround is sketched after this list).
- English language only.
- Consider privacy when processing sensitive text.
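
Because results may degrade on very long inputs, one possible workaround (a rough sketch, not part of the model card's original code) is to split the document into overlapping character chunks, extract entities per chunk with the `find_entities` function from the example above, and deduplicate the results. The chunk size and overlap below are arbitrary assumptions.

```python
def find_entities_long(text, entity_type, chunk_chars=1000, overlap=100):
    """Naive character-based chunking; chunk_chars and overlap are arbitrary choices."""
    seen, entities = set(), []
    start = 0
    while start < len(text):
        chunk = text[start:start + chunk_chars]
        for entity in find_entities(chunk, entity_type):  # defined in the usage example above
            if entity not in seen:
                seen.add(entity)
                entities.append(entity)
        start += chunk_chars - overlap
    return entities
```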

## Training Procedure

### Training hyperparameters

The following hyperparameters were used during training (an approximate `Seq2SeqTrainingArguments` equivalent is sketched after this list):

- learning_rate: 5e-05
- train_batch_size: 8
- eval_batch_size: 8
- seed: 42
- optimizer: AdamW (adamw_torch) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: linear
- num_epochs: 5.0
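
For reference, the hyperparameters above roughly correspond to a `Seq2SeqTrainingArguments` configuration like the sketch below. The output directory is a placeholder, and other settings of the actual run (logging, saving, mixed precision, etc.) are not reproduced here.

```python
from transformers import Seq2SeqTrainingArguments

# Approximate reconstruction of the listed hyperparameters;
# "./flan-t5-small-ner" is a placeholder output directory.
training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-small-ner",
    learning_rate=5e-05,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    seed=42,
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-08,
    lr_scheduler_type="linear",
    num_train_epochs=5.0,
)
```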

### Training results

| Training Loss | Epoch | Step  | Validation Loss | Input Tokens Seen |
|:-------------:|:-----:|:-----:|:---------------:|:-----------------:|
| 0.8398        | 1.0   | 19991 | 0.6227          | 66451084          |
| 0.7203        | 2.0   | 39982 | 0.5679          | 132976438         |
| 0.6479        | 3.0   | 59973 | 0.5605          | 199402582         |
| 0.6023        | 4.0   | 79964 | 0.5427          | 265875340         |
| 0.5879        | 5.0   | 99955 | 0.5393          | 332318598         |

## Framework Versions

- Transformers: 4.46.3
- PyTorch: 2.5.1+cu124
- Datasets: 3.2.0
- Tokenizers: 0.20.3