# KhangTruong's picture
# Super-squash branch 'main' using huggingface_hub
# 6f31f53 verified
from .imports import *
def get_strategy():
    """Detect an attached TPU and return the matching tf.distribute strategy.

    Returns:
        tf.distribute.TPUStrategy when a TPU cluster is reachable, otherwise
        the default strategy (CPU/GPU) from tf.distribute.get_strategy().
    """
    try:
        # TPUClusterResolver raises ValueError when no TPU is configured.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
        print('Running on TPU ', tpu.master())
    except ValueError:
        tpu = None
    if tpu:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)
    else:
        # If TPU is not available, use default strategy (CPU/GPU)
        strategy = tf.distribute.get_strategy()
        print('Running on CPU/GPU')
    # BUG FIX: the original built `strategy` but fell off the end returning
    # None; a function named get_strategy must hand the strategy back.
    # (The dead local BATCH_SIZE = 100 was also removed — never read.)
    return strategy
def get_data_json() -> dict[str, list[dict[str, str | int]]]:
fname = './02_NWPU_caption/dataset_nwpu.json'
with open(fname) as f:
data = json.load(f)
return data
def extractor():
    """Build {image_path: [word-list, ...]} from the raw NWPU caption JSON.

    Each of the five raw captions per image has non-ASCII runs replaced by
    spaces and is then split on spaces/commas into a list of words.
    """
    def tokenize(label: str):
        # Strip non-ASCII runs first, then split on spaces and commas,
        # dropping the empty fragments the split produces.
        cleaned = re.sub(r'[^\x00-\x7F]+', ' ', label)
        return [piece for piece in re.split('[ ,]', cleaned) if piece]

    raw_data = get_data_json()
    captions_by_image = {}
    for category, entries in raw_data.items():
        for entry in entries:
            image_path = f'02_NWPU_RESISC45/{category}/{entry["filename"]}'
            captions_by_image[image_path] = [
                entry['raw'],
                entry['raw_1'],
                entry['raw_2'],
                entry['raw_3'],
                entry['raw_4'],
            ]
    return {
        path: [tokenize(caption) for caption in captions]
        for path, captions in captions_by_image.items()
    }
def store_tokenizer_and_data():
    """Build a word->id tokenizer over all captions and persist everything.

    Writes 'tokenizer.json' (word -> integer id) and 'data.json'
    (image path -> list of [begin]/[end]-wrapped token-id sentences).
    """
    captions = extractor()
    vocabulary = set()
    for sentences in captions.values():
        for sentence in sentences:
            for word in sentence:
                vocabulary.add(word.lower())
    # Id 0 is the empty/padding token; [begin]/[end] wrap every sentence.
    ordered_words = ['', '[begin]', '[end]'] + sorted(vocabulary)
    tokenizer = {word: index for index, word in enumerate(ordered_words)}
    tokenized_data = {}
    for name, sentences in captions.items():
        tokenized_data[name] = [
            [tokenizer[word.lower()] for word in ['[begin]'] + sentence + ['[end]']]
            for sentence in sentences
        ]
    with open('tokenizer.json', 'w') as f:
        json.dump(tokenizer, f)
    with open('data.json', 'w') as f:
        json.dump(tokenized_data, f)
def load_tokenizer():
    """Read the persisted word->id tokenizer back from 'tokenizer.json'."""
    with open('tokenizer.json') as handle:
        tokenizer = json.load(handle)
    return tokenizer
def split_data():
    """Split the tokenized captions 80/20 into 'train.json' and 'test.json'.

    Iterating keys of 'data.json' in order, every fifth entry (index % 5 == 4)
    goes to the test set and the rest to the training set. Each output entry
    is a (key, captions) pair.
    """
    with open('data.json') as handle:
        data = json.load(handle)
    train_entries = []
    test_entries = []
    for index, key in enumerate(data):
        bucket = test_entries if index % 5 > 3 else train_entries
        bucket.append((key, data[key]))
    with open('train.json', 'w') as handle:
        json.dump(train_entries, handle)
    with open('test.json', 'w') as handle:
        json.dump(test_entries, handle)