# Healthmodels / src/continual_train.py
# theaniketgiri's picture
# first
# 902fa1b
import os
from src.train import train_vae
import pandas as pd
def continual_train(progress_callback=None):
    """Fine-tune the VAE after merging newly arrived data.

    Reads ``data/new_data.csv``, appends its rows to
    ``data/processed_patient_data.csv`` (or creates that file from the new
    data when it does not exist yet), writes the merged table back to disk
    and then calls ``train_vae``.

    The new data is assumed to be already preprocessed; no preprocessing is
    performed here.

    Args:
        progress_callback: Optional value forwarded unchanged to
            ``train_vae`` as its ``progress_callback`` keyword argument.

    Returns:
        None. When ``data/new_data.csv`` does not exist, a message is
        printed and the function returns without merging or training.
    """
    new_data_path = "data/new_data.csv"
    processed_path = "data/processed_patient_data.csv"

    # Guard clause: nothing to do without a new-data file.
    if not os.path.exists(new_data_path):
        print("No new data found for continual training.")
        return

    print("Loading all processed data for fine-tuning...")

    # The new data is needed whether or not a processed file already exists,
    # so it is read exactly once up front.
    new_df = pd.read_csv(new_data_path)

    if os.path.exists(processed_path):
        existing_df = pd.read_csv(processed_path)
        feature_df = pd.concat([existing_df, new_df], ignore_index=True)
    else:
        feature_df = new_df

    # NOTE(review): the new-data file is neither removed nor marked as
    # consumed, so calling this function again appends the same rows to the
    # processed file a second time - confirm whether that is intended.
    feature_df.to_csv(processed_path, index=False)

    print(f"Fine-tuning on {feature_df.shape[0]} samples...")

    # train_vae receives no data argument; presumably it loads the processed
    # CSV written above itself - verify against src.train.
    train_vae(progress_callback=progress_callback)