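"""Train an LSTM sentiment classifier on preprocessed review data.

Builds a TextVectorization layer from a saved token counter, splits the
data into stratified train/test sets, trains the model with early stopping
and checkpointing, and saves the result in H5 format. Paths and
hyperparameters come from the `config` module.
"""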
import logging
import os
import pickle
from pathlib import Path

import numpy as np
import polars as pl
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
from tensorflow.keras.layers import TextVectorization

import config
from src import utils, make_dataset, model

# constants
DATA_DIR = Path(os.getcwd()) / 'dataset'
DATA_PATH = DATA_DIR / 'preprocessed_df.csv'
MODEL_PATH = Path(config.MODEL_DIR) / config.MODEL_FILENAME
VECTORIZER_PATH = Path(config.MODEL_DIR) / config.TEXT_VECTOR_FILENAME
COUNTER_PATH = Path(config.MODEL_DIR) / config.COUNTER_NAME

def set_global_seed(seed):
    """Seed NumPy and TensorFlow RNGs for reproducible runs."""
    np.random.seed(seed)
    tf.random.set_seed(seed)
    global random_state
    random_state = check_random_state(seed)
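
# NOTE: this seeds NumPy and TensorFlow, but GPU kernels can still be
# non-deterministic; tf.config.experimental.enable_op_determinism()
# (TF >= 2.9) would also be needed for fully deterministic training.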

def read_data(data_path, train_size: float = 1.0):
    """Read the preprocessed CSV and optionally keep a random fraction of rows."""
    logging.info('Reading data...')
    df = pl.read_csv(data_path)
    n_rows = int(df.shape[0] * train_size)  # number of rows to keep
    df = df.sample(n_rows, seed=config.SEED)
    logging.info(f'Data shape after sampling: {df.shape}')
    return df
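
# Example: read_data(DATA_PATH, 0.1) would train on a random 10% of the
# rows (sampled with the configured seed), which is handy for smoke tests.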

def main():
    # Call the function to set the seeds
    set_global_seed(config.SEED)
    utils.configure_logging(config.LOG_DIR, "training_log.txt", log_level=logging.INFO)
    df = read_data(DATA_PATH, config.TRAIN_SIZE)
    logging.info(f'GPU count: {len(tf.config.list_physical_devices("GPU"))}')
    counter = utils.load_counter(COUNTER_PATH)
    # Text vectorization
    logging.info('Building text vectorizer...')
    text_vectorizer = TextVectorization(
        max_tokens=config.MAX_TOKEN,
        standardize='lower_and_strip_punctuation',
        split='whitespace',
        ngrams=None,
        output_mode='int',
        output_sequence_length=config.OUTPUT_SEQUENCE_LENGTH,
        # pad_to_max_tokens is only valid for 'multi_hot'/'count'/'tf_idf'
        # output modes, so it is not passed here with output_mode='int'.
        vocabulary=list(counter.keys())[:config.MAX_TOKEN - 2])
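    # NOTE: with an explicit `vocabulary`, the layer does not need adapt().
    # Keras reserves index 0 for padding ('') and index 1 for OOV ('[UNK]'),
    # which is why only MAX_TOKEN - 2 tokens from the counter are passed in.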
| logging.info(f"text vectorizer vocab size: {text_vectorizer.vocabulary_size()}") | |
    # Create datasets
    logging.info('Preparing dataset...')
    xtrain, xtest, ytrain, ytest = train_test_split(
        df.select('review'), df.select('polarity'),
        test_size=config.TEST_SIZE, random_state=config.SEED,
        stratify=df['polarity'])
    del df
    train_len = xtrain.shape[0] // config.BATCH_SIZE
    test_len = xtest.shape[0] // config.BATCH_SIZE
    logging.info('Preparing train dataset...')
    train_dataset = make_dataset.create_datasets(
        xtrain, ytrain, text_vectorizer, batch_size=config.BATCH_SIZE, shuffle=False)
    del xtrain, ytrain
    logging.info('Preparing test dataset...')
    test_dataset = make_dataset.create_datasets(
        xtest, ytest, text_vectorizer, batch_size=config.BATCH_SIZE, shuffle=False)
    del xtest, ytest, counter, text_vectorizer
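    # The raw splits, counter and vectorizer are deleted once the tf.data
    # pipelines hold their own references, keeping peak memory down.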
    logging.info('Model loading...')
    # Build and compile the LSTM model
    lstm_model = model.create_lstm_model(
        input_shape=(config.OUTPUT_SEQUENCE_LENGTH,),
        max_tokens=config.MAX_TOKEN, dim=config.DIM)
    lstm_model.compile(
        optimizer=tf.keras.optimizers.Nadam(learning_rate=config.LEARNING_RATE),
        loss=tf.keras.losses.BinaryCrossentropy(),
        # 'accuracy' resolves to BinaryAccuracy for this loss; the string
        # 'Accuracy' would select the exact-match Accuracy metric instead.
        metrics=['accuracy'])
    lstm_model.summary()
    # Callbacks
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='loss', patience=config.EARLY_STOPPING_PATIENCE,
            restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint(
            monitor='loss', filepath=str(MODEL_PATH), save_best_only=True),
    ]
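    # NOTE: both callbacks monitor the training loss; monitoring 'val_loss'
    # is the more common choice when a validation set is available, as here.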
    # Load previous weights if a checkpoint exists, so training can resume
    try:
        lstm_model.load_weights(MODEL_PATH)
        logging.info('Model weights loaded!')
    except Exception as e:
        logging.error(f'Exception occurred while loading model weights: {e}')
    # Training
    logging.info('Model training...')
    lstm_history = lstm_model.fit(
        train_dataset, validation_data=test_dataset, epochs=config.EPOCHS,
        steps_per_epoch=int(train_len / config.EPOCHS),
        validation_steps=int(test_len / config.EPOCHS),
        callbacks=callbacks)
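    # steps_per_epoch = train_len / EPOCHS spreads a single pass over the
    # training batches across all epochs, so the data is seen roughly once
    # over the whole run rather than once per epoch.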
    logging.info('Training Complete!')
    logging.info('Training history:')
    logging.info(lstm_history.history)
    print(pl.DataFrame(lstm_history.history))
    # Save the trained LSTM model
    logging.info('Saving Model')
    lstm_model.save(str(MODEL_PATH), save_format='h5')
    logging.info('Done')
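    # NOTE: only the model is saved here; the TextVectorization layer is not
    # bundled with it, so inference code presumably rebuilds it from the saved
    # counter (COUNTER_PATH) or a serialized copy at VECTORIZER_PATH.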

if __name__ == "__main__":
    main()