Vikrantyadav11234 committed on
Commit
f35a7b1
·
verified ·
1 Parent(s): da7b0e1

Upload create_inline_tags.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. create_inline_tags.py +170 -0
create_inline_tags.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
+ import fasttext
4
+ from indic_transliteration import sanscript
5
+ from indic_transliteration.sanscript import transliterate
6
+ import re
7
+ from tqdm import tqdm
8
+ from functools import lru_cache
9
+ import os
10
+ import urllib.request
11
+
12
# Pick the compute device once: GPU when CUDA is available, otherwise CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")
print(f"Using device: {device}")

# Token-classification checkpoint used for NER, with its tokenizer.
ner_model_name = "xlm-roberta-large-finetuned-conll03-english"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_model = ner_model.to(device)

# NER pipeline; aggregation_strategy="simple" makes it report grouped
# entities (each result carries 'word' and 'entity_group').
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    device=0 if _cuda_available else -1,
    aggregation_strategy="simple",
)

# fastText language-identification model. The file must already exist at
# this path; nothing is downloaded here. lid.176.ftz is published at
# https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz
fasttext_model = fasttext.load_model("/home/vikrant-MNMT/myenv/fasttext_model/lid.176.ftz")
33
+
34
@lru_cache(maxsize=10000)
def extract_entities(sentence):
    """Run the NER pipeline on ``sentence`` and return its entities.

    Args:
        sentence: Text to tag. It is also the cache key, so it must be
            hashable (a ``str``).

    Returns:
        A tuple of ``(word, entity_group)`` pairs, one per entity reported
        by ``ner_pipeline``, in the order the pipeline returned them.
        Results are memoized (LRU, up to 10000 distinct sentences).
    """
    pairs = []
    for entity in ner_pipeline(sentence):
        pairs.append((entity['word'], entity['entity_group']))
    return tuple(pairs)
38
+
39
@lru_cache(maxsize=10000)
def detect_language(text):
    """Return the language code fastText predicts for ``text``.

    Args:
        text: Text to classify. It is also the cache key.

    Returns:
        The top-1 label with its ``__label__`` prefix removed, e.g. ``'en'``.
        Results are memoized (LRU, up to 10000 distinct inputs).
    """
    # fastText's predict() handles one line at a time and raises ValueError
    # if the input contains a newline, so flatten line breaks first.
    single_line = text.replace('\n', ' ')
    labels, _scores = fasttext_model.predict(single_line, k=1)
    return labels[0].split('__label__')[1]
43
+
44
# Language code -> name of the ``sanscript`` scheme constant for its script.
# Names are resolved lazily with getattr so a scheme is only looked up when
# a text in that language is actually transliterated.
_SCHEME_ATTR_BY_LANG = {
    'hi': 'DEVANAGARI',
    'mr': 'DEVANAGARI',
    'pa': 'GURMUKHI',
    'gu': 'GUJARATI',
    'bn': 'BENGALI',  # Bengali and Assamese use the same script
    'as': 'BENGALI',
    'ml': 'MALAYALAM',
    'ta': 'TAMIL',
    'te': 'TELUGU',
    'kn': 'KANNADA',
    'or': 'ORIYA',
}


@lru_cache(maxsize=10000)
def transliterate_to_latin(text, lang):
    """Transliterate ``text`` from the script of ``lang`` into ITRANS.

    Args:
        text: Text to transliterate.
        lang: Language code (``'hi'``, ``'ta'``, ...) that selects the
            source script.

    Returns:
        The ITRANS transliteration for supported languages. For any other
        language code the text is returned unchanged. That includes
        ``'ur'``; the original author's note says Urdu is already in Latin
        script in their test cases.
    """
    scheme_attr = _SCHEME_ATTR_BY_LANG.get(lang)
    if scheme_attr is None:
        # Unsupported language (or Urdu): return as is.
        return text
    source_scheme = getattr(sanscript, scheme_attr)
    return transliterate(text, source_scheme, sanscript.ITRANS)
68
+
69
# Matches every character that is not an ASCII letter or digit.
_NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9]')


@lru_cache(maxsize=100000)
def normalize(text):
    """Lowercase ``text`` and keep only its ASCII letters and digits.

    Spaces, punctuation and every non-ASCII character are dropped, so text
    made up only of such characters normalizes to the empty string.
    Results are memoized (LRU, up to 100000 distinct inputs).
    """
    lowered = text.lower()
    return _NON_ALNUM_RE.sub('', lowered)
73
+
74
def partial_match(s1, s2, threshold=0.7):
    """Return True when two strings loosely refer to the same name.

    Both strings are passed through ``normalize`` first. They match when
    one normalized form contains the other, or when the first is at least
    four characters long and both share the same four-character prefix.

    Args:
        s1: First string.
        s2: Second string.
        threshold: Accepted for backward compatibility; the current
            matching rules do not use it.

    Returns:
        ``True`` on a match, ``False`` otherwise. If either string
        normalizes to the empty string the result is ``False``.
    """
    first = normalize(s1)
    second = normalize(s2)

    # An empty string is a substring of every string, so without this guard
    # punctuation-only or non-Latin input would "match" anything.
    if not first or not second:
        return False

    if first in second or second in first:
        return True
    return len(first) >= 4 and first[:4] == second[:4]
79
+
80
def process_pair(source, target):
    """Align named entities between an English sentence and its translation.

    The language of each sentence is detected; whichever side is English is
    paired against the other. Entities from both sides are compared, and two
    entities are aligned when they carry the same entity tag and
    ``partial_match`` accepts the normalized English word against the
    normalized transliteration of the other word.

    Args:
        source: Source-side sentence (surrounding whitespace is stripped).
        target: Target-side sentence (surrounding whitespace is stripped).

    Returns:
        A pair of lists ``(en_to_other, other_to_en)`` of tab-separated tag
        lines, ``"en: <word>\\t<lang>: <word>\\t<tag>"`` and the mirrored
        ``"<lang>: <word>\\ten: <word>\\t<tag>"``. Both lists are empty when
        neither sentence is detected as English.
    """
    source = source.strip()
    target = target.strip()

    source_lang = detect_language(source)
    target_lang = detect_language(target)

    # Determine which sentence is English
    if source_lang == 'en':
        en_entities, other_entities = extract_entities(source), extract_entities(target)
        other_lang = target_lang
    elif target_lang == 'en':
        en_entities, other_entities = extract_entities(target), extract_entities(source)
        other_lang = source_lang
    else:
        return [], []  # If neither is English, return no tags

    pair_tags_en_other = []
    pair_tags_other_en = []
    for en_word, en_tag in en_entities:
        # Loop-invariant for the inner loop, so compute it once per entity.
        en_norm = normalize(en_word)
        if not en_norm:
            # Nothing comparable is left after normalization.
            continue
        for other_word, other_tag in other_entities:
            if en_tag != other_tag:
                continue
            other_trans = transliterate_to_latin(other_word, other_lang)
            other_norm = normalize(other_trans)
            if not other_norm:
                # Text that could not be romanized normalizes to '' and
                # would otherwise be a substring of every English entity.
                continue
            if partial_match(en_norm, other_norm):
                # Both normalized forms are non-empty here, so neither
                # original word is blank.
                pair_tags_en_other.append(f"en: {en_word}\t{other_lang}: {other_word}\t{en_tag}")
                pair_tags_other_en.append(f"{other_lang}: {other_word}\ten: {en_word}\t{en_tag}")

    return pair_tags_en_other, pair_tags_other_en
114
+
115
def batch_generator(source_file, target_file, batch_size):
    """Yield parallel batches of lines from two line-aligned text files.

    Both files are opened as UTF-8 and read in lockstep; reading stops at
    the end of the shorter file. Lines keep their trailing newline.

    Args:
        source_file: Path of the source-side file.
        target_file: Path of the target-side file.
        batch_size: Number of line pairs per batch.

    Yields:
        ``(source_lines, target_lines)``: two lists of equal length. Every
        batch has ``batch_size`` entries except possibly the last, which
        holds the remaining lines.
    """
    with open(source_file, 'r', encoding='utf-8') as src_handle:
        with open(target_file, 'r', encoding='utf-8') as tgt_handle:
            src_lines = []
            tgt_lines = []
            for src_line, tgt_line in zip(src_handle, tgt_handle):
                src_lines.append(src_line)
                tgt_lines.append(tgt_line)
                if len(src_lines) != batch_size:
                    continue
                yield src_lines, tgt_lines
                # Start fresh lists so batches already handed out stay intact.
                src_lines = []
                tgt_lines = []
            # Emit the final, partially filled batch.
            if src_lines:
                yield src_lines, tgt_lines
126
+
127
def create_dataset(source_file, target_file, output_file_en_other, output_file_other_en, batch_size=32):
    """Extract aligned entity tags from a parallel corpus into two files.

    The corpus is read in batches; each sentence pair goes through
    ``process_pair`` and the resulting tag lines are appended to the two
    output files, which are flushed after every batch that produced tags.

    Args:
        source_file: Path of the source-side corpus (one sentence per line).
        target_file: Path of the target-side corpus, line-aligned with
            ``source_file``.
        output_file_en_other: Output path for "en -> other" tag lines.
        output_file_other_en: Output path for "other -> en" tag lines.
        batch_size: Number of sentence pairs processed per batch.
    """
    # Count lines inside a context manager so the handle is closed promptly.
    with open(source_file, 'r', encoding='utf-8') as line_counter:
        total_lines = sum(1 for _ in line_counter)
    print(f"Processing {total_lines} lines from {source_file} and {target_file}")

    # Ceiling division: a trailing partial batch still counts as a batch,
    # so the progress bar total matches the number of batches yielded.
    total_batches = -(-total_lines // batch_size)

    total_tags_en_other = 0
    total_tags_other_en = 0
    with open(output_file_en_other, "w", encoding="utf-8") as f_en_other, \
            open(output_file_other_en, "w", encoding="utf-8") as f_other_en:
        batches = tqdm(batch_generator(source_file, target_file, batch_size), total=total_batches)
        for i, (source_batch, target_batch) in enumerate(batches):
            batch_tags_en_other = []
            batch_tags_other_en = []
            for source, target in zip(source_batch, target_batch):
                pair_tags_en_other, pair_tags_other_en = process_pair(source, target)
                batch_tags_en_other.extend(pair_tags_en_other)
                batch_tags_other_en.extend(pair_tags_other_en)

            if batch_tags_en_other:
                f_en_other.write("\n".join(batch_tags_en_other) + "\n")
                f_en_other.flush()  # Ensure data is written to disk
                total_tags_en_other += len(batch_tags_en_other)

            if batch_tags_other_en:
                f_other_en.write("\n".join(batch_tags_other_en) + "\n")
                f_other_en.flush()  # Ensure data is written to disk
                total_tags_other_en += len(batch_tags_other_en)

            if (i + 1) % 1000 == 0:
                # Cap at total_lines so a final partial batch is not overstated.
                lines_done = min((i + 1) * batch_size, total_lines)
                print(f"Processed {lines_done} lines. Current tag count: {total_tags_en_other} (en-other), {total_tags_other_en} (other-en)")

    print(f"Inline tags extraction completed. {total_tags_en_other} tags saved to {output_file_en_other}.")
    print(f"Inline tags extraction completed. {total_tags_other_en} tags saved to {output_file_other_en}.")
159
+
160
def main():
    """Run the tag extraction on the hard-coded training corpus paths.

    Reads the shuffled ``.src`` / ``.tgt`` training files and writes the two
    tag files under the ``eng_Latn-hin_Deva`` output directory, processing
    1000 sentence pairs per batch.
    """
    create_dataset(
        source_file='/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.src',
        target_file='/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.tgt',
        output_file_en_other="/home/vikrant-MNMT/myenv/BPCC/inline_tages/eng_Latn-hin_Deva/inline_tag_1.txt",
        output_file_other_en="/home/vikrant-MNMT/myenv/BPCC/inline_tages/eng_Latn-hin_Deva/inline_tag_2.txt",
        batch_size=1000,
    )


if __name__ == "__main__":
    main()