import copy
import json
import random
import re
from collections import Counter
from functools import lru_cache

from pyvi.ViTokenizer import tokenize
from sentence_transformers import SentenceTransformer, util

# Fix the RNG seed so augmentation runs are reproducible.
random.seed(42)
|
|
@lru_cache(maxsize=None)
def find_similar_entity_with_simcse(entity_type, entity_value, intent):
    """Return the alternative filler of the same type (for the given intent)
    whose SimCSE embedding is closest to entity_value's; fall back to
    entity_value itself when no alternative exists."""
    possible_entities = intent_entities[intent].get(entity_type, [])
    possible_entities = [e for e in possible_entities if e != entity_value]

    if not possible_entities:
        return entity_value

    # Word-segment the Vietnamese text before encoding, as PhoBERT-based
    # models expect pyvi-tokenised input.
    tokenized_source = tokenize(entity_value)
    embed_source = model.encode(tokenized_source)

    # Cosine similarity against the pre-computed candidate embeddings.
    similarities = [
        util.pytorch_cos_sim(embed_source, entity_to_embedding[possible_entity]).item()
        for possible_entity in possible_entities
    ]

    return possible_entities[similarities.index(max(similarities))]
|
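# Design note: the lru_cache above memoises on the (entity_type,
# entity_value, intent) argument tuple, so a filler that recurs across many
# samples is encoded and compared only once per run.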
|
def regenerate_entities_from_annotation(annotation):
    """
    Extract entities and their types from the sentence_annotation, which
    marks each entity inline as "[ <type> : <filler> ]".
    Return a list of entities with their type and filler.
    """
    entity_pattern = re.compile(r'\[ ([^\]]+) : ([^\]]+) \]')
    entities = []
    for entity_type, filler in entity_pattern.findall(annotation):
        entities.append({
            'type': entity_type,
            'filler': filler
        })
    return entities
|
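# A usage sketch with a hypothetical annotation string (following the
# "[ <type> : <filler> ]" format matched above):
#   regenerate_entities_from_annotation("turn on [ device : fan ]")
#   -> [{'type': 'device', 'filler': 'fan'}]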
|
def augment_and_correct_all(sample, augmentation_ratio=0.5):
    """
    Augment based on the sentence_annotation, reconstruct the sentence,
    and regenerate the entities list.
    """
    # Only augment roughly `augmentation_ratio` of the samples.
    if random.random() > augmentation_ratio:
        return None

    # Guard: random.randint(1, 0) below would raise on entity-less samples.
    if not sample['entities']:
        return None

    augmented_sample = copy.deepcopy(sample)
    annotation = sample['sentence_annotation']

    # Swap a random-length prefix of the entities for similar fillers.
    num_entities_to_replace = random.randint(1, len(sample['entities']))
    for entity in reversed(sample['entities'][:num_entities_to_replace]):
        similar_entity = find_similar_entity_with_simcse(entity['type'], entity['filler'], sample['intent'])
        if similar_entity != entity['filler']:
            annotation = annotation.replace(
                f"[ {entity['type']} : {entity['filler']} ]",
                f"[ {entity['type']} : {similar_entity} ]",
            )

    # Strip the "[ type : filler ]" markup to recover the plain sentence.
    sentence = re.sub(r'\[ [^\]]+ : ([^\]]+) \]', r'\1', annotation)

    # Drop augmentations that did not actually change the sentence.
    if sentence == sample['sentence']:
        return None

    augmented_sample['sentence'] = sentence
    augmented_sample['sentence_annotation'] = annotation
    augmented_sample['entities'] = regenerate_entities_from_annotation(annotation)

    return augmented_sample
|
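# Note: str.replace above substitutes every occurrence of an identical
# "[ type : filler ]" span, not just the one belonging to the current entity;
# that is harmless here because all duplicates receive the same replacement
# filler.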
|
def augment_entry_owner(entry, chosen_word):
    """Append an owner phrase (" của <chosen_word>"; "của" is the Vietnamese
    possessive, roughly "of") to the first device filler."""
    device_filler = next(entity['filler'] for entity in entry['entities'] if entity['type'] == 'device')
    owned_filler = device_filler + " của " + chosen_word
    augmented_sentence = entry['sentence'].replace(device_filler, owned_filler, 1)
    augmented_annotation = entry['sentence_annotation'].replace(device_filler, owned_filler, 1)

    # Build the entity list from the new annotation rather than mutating
    # entry['entities'] in place, which would corrupt the original training
    # sample that entry still references.
    return {
        'sentence': augmented_sentence,
        'intent': entry['intent'],
        'sentence_annotation': augmented_annotation,
        'entities': regenerate_entities_from_annotation(augmented_annotation),
    }
|
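# A hedged before/after sketch with hypothetical values: with chosen_word =
# "anh Nam", an annotation containing "[ device : quạt ]" would become
# "[ device : quạt của anh Nam ]", and the plain sentence is updated the
# same way.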
|
def augment_entry_num_loc(entry, chosen_word):
    """Append a numeric or positional qualifier (e.g. "số 2", "bên trái")
    to the first device filler."""
    device_filler = next(entity['filler'] for entity in entry['entities'] if entity['type'] == 'device')
    qualified_filler = device_filler + " " + chosen_word
    augmented_sentence = entry['sentence'].replace(device_filler, qualified_filler, 1)
    augmented_annotation = entry['sentence_annotation'].replace(device_filler, qualified_filler, 1)

    # As in augment_entry_owner, regenerate the entities from the annotation
    # instead of mutating the original entry's list in place (the original
    # version returned entry['entities'] itself, a shared, mutated reference).
    return {
        'sentence': augmented_sentence,
        'intent': entry['intent'],
        'sentence_annotation': augmented_annotation,
        'entities': regenerate_entities_from_annotation(augmented_annotation),
    }
|
|
with open("raw_data/intent_entities.json", 'r') as f: |
|
|
intent_entities = json.load(f) |
|
|
|
|
|
with open("raw_data/train_final_20230919.jsonl", 'r') as f: |
|
|
training_samples = [json.loads(line) for line in f.readlines()] |
|
|
|
|
|
for sample in training_samples: |
|
|
del sample['id'] |
|
|
del sample['file'] |
|
|
|
|
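# A hedged sketch of the record shape this script assumes, inferred from the
# field accesses in this file (the real data may carry more fields):
#   {"sentence": "...", "intent": "...",
#    "sentence_annotation": "... [ <type> : <filler> ] ...",
#    "entities": [{"type": "...", "filler": "..."}],
#    "id": "...", "file": "..."}   # "id"/"file" are dropped above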
|
# Vietnamese SimCSE encoder (PhoBERT-based).
model = SentenceTransformer('VoVanPhuc/sup-SimCSE-VietNamese-phobert-base')

# Pre-compute an embedding for every known entity value so the similarity
# search only has to encode the query side.
all_entities = set()
for entities in intent_entities.values():
    for entity_values in entities.values():
        all_entities.update(entity_values)

# Sort for a deterministic iteration order across runs.
all_entities = sorted(all_entities)
all_entities_tokenized = [tokenize(entity) for entity in all_entities]
entity_to_embedding = {
    entity: model.encode(tokenized_entity)
    for entity, tokenized_entity in zip(all_entities, all_entities_tokenized)
}
|
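# Sanity check (hypothetical shape): each cached embedding is a 1-D vector,
# e.g. next(iter(entity_to_embedding.values())).shape would be (768,) for a
# PhoBERT-base backbone.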
|
# Count samples per intent and oversample the least frequent intents via
# entity swapping.
intent_counts = Counter(sample['intent'] for sample in training_samples)

num_intents_to_augment = 4
less_frequent_intents = sorted(intent_counts, key=intent_counts.get)[:num_intents_to_augment]

augmented_samples = []
for intent in less_frequent_intents:
    samples_for_intent = [sample for sample in training_samples if sample['intent'] == intent]
    for sample in samples_for_intent:
        augmented_sample = augment_and_correct_all(sample)
        if augmented_sample:
            augmented_samples.append(augmented_sample)
|
|
# Deduplicate the entity-swapped samples by sentence.
seen_sentences = set()
swapped_samples = []
for sample in augmented_samples:
    sentence = sample['sentence']
    if sentence not in seen_sentences:
        seen_sentences.add(sentence)
        swapped_samples.append(sample)
|
|
# Harvest owner phrases: for every location filler that contains the
# possessive "của" ("of"), collect the words that follow it, stopping at any
# stray "]" left over from annotation markup.
words_after_cua_in_locations = []

for entry in training_samples:
    location_entities = [entity for entity in entry['entities'] if entity['type'] == 'location']

    for loc in location_entities:
        if "của" in loc['filler']:
            parts = loc['filler'].split("của", 1)
            if len(parts) > 1:
                words = parts[1].split()

                phrase_after_cua = []
                for word in words:
                    if "]" in word:
                        phrase_after_cua.append(word.replace("]", ""))
                        break
                    phrase_after_cua.append(word)

                if phrase_after_cua:
                    words_after_cua_in_locations.append(" ".join(phrase_after_cua))
|
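# Hypothetical illustration: a location filler such as "phòng của anh Nam"
# ("Nam's room") would contribute the owner phrase "anh Nam".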
|
unique_phrases_after_cua = list(set(words_after_cua_in_locations))

# Attach a random owner phrase to the device filler of roughly 1/15 of the
# entries that contain a device entity.
device_entries = [entry for entry in training_samples if any(entity['type'] == 'device' for entity in entry['entities'])]
augment_count = len(device_entries) // 15
to_augment = random.sample(device_entries, augment_count)

augmented_entries_owner = [augment_entry_owner(entry, random.choice(unique_phrases_after_cua)) for entry in to_augment]
|
|
# Harvest numeric ("số <n>", i.e. "number <n>") and positional qualifiers
# from location fillers.
so_numbers = []
directions = ["bên trái", "bên phải"]  # "left side", "right side"

for entry in training_samples:
    for entity in entry.get('entities', []):
        if entity['type'] == 'location':
            filler = entity['filler']

            match = re.search(r'số (\d+)', filler)
            if match:
                so_numbers.append(match.group(0))

            for direction in directions:
                if direction in filler:
                    so_numbers.append(direction)
|
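# Hypothetical illustration: a location filler such as "phòng số 2 bên trái"
# would add both "số 2" and "bên trái" to so_numbers.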
|
# Re-draw a fresh random subset of the device entries (device_entries and
# augment_count are unchanged from above) and append a random
# numeric/positional qualifier to their device fillers.
to_augment = random.sample(device_entries, augment_count)

augmented_entries_num = [augment_entry_num_loc(entry, random.choice(so_numbers)) for entry in to_augment]
|
|
# Merge the original data with all augmented variants and write it out.
combined_augmented_samples = training_samples + swapped_samples + augmented_entries_owner + augmented_entries_num

combined_augmented_file_path = 'augmented_data.jsonl'
with open(combined_augmented_file_path, 'w', encoding='utf-8') as file:
    for entry in combined_augmented_samples:
        # ensure_ascii=False keeps the Vietnamese text human-readable.
        file.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
|
# Reload the combined file and keep only the first occurrence of each sentence.
with open(combined_augmented_file_path, 'r', encoding='utf-8') as f:
    samples = [json.loads(line) for line in f]

seen_sentences = set()
unique_samples = []
for sample in samples:
    if sample['sentence'] not in seen_sentences:
        seen_sentences.add(sample['sentence'])
        unique_samples.append(sample)
|
|
cleaned_file_path = combined_augmented_file_path.replace(".jsonl", "_unique.jsonl")
with open(cleaned_file_path, 'w', encoding='utf-8') as f:
    for sample in unique_samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

print(f"Cleaned data saved to: {cleaned_file_path}")