File size: 1,561 Bytes
a0e0ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from model import CustomBERTModel
from config import Config
import pandas as pd

def load_data(file_path):
    """Load a header-less CSV file as a 2-D ``float32`` tensor.

    The first line of the file is treated as data, not as column names
    (``header=None``). Each data row of the CSV becomes one row of the
    returned tensor, so the result has shape ``(num_rows, num_columns)``.
    Cells are expected to be numeric; non-numeric content makes the tensor
    conversion raise.

    Args:
        file_path: Path (or anything ``pd.read_csv`` accepts) of the CSV file.

    Returns:
        A ``torch.float32`` tensor holding the file's values.
    """
    frame = pd.read_csv(file_path, header=None)
    values = frame.to_numpy()
    return torch.tensor(values, dtype=torch.float32)

def get_embeddings(input_file, output_file, model_path="bert_mlm_model.pth"):
    """Encode every row of a CSV file with the trained model and save the result.

    Builds ``CustomBERTModel`` from a fresh ``Config()``, loads the weights
    stored at *model_path*, feeds the rows returned by
    ``load_data(input_file)`` through ``model.get_encoder_output`` in batches
    of ``config.batch_size``, concatenates the per-batch outputs along the
    first axis and writes them with ``np.save``.

    Args:
        input_file: Path to the header-less CSV file read by ``load_data``;
            each data row is one model input.
        output_file: Destination passed to ``np.save``. When given as a
            string without a ``.npy`` suffix, NumPy appends one, so the file
            on disk may differ from the printed name.
        model_path: Path of the saved ``state_dict``. Defaults to the
            checkpoint name this script has always used.
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)

    # map_location remaps stored tensors onto config.device. Without it,
    # torch.load restores each tensor to the device it was saved from,
    # which fails when e.g. a GPU-trained checkpoint is loaded on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(model_path, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for layers such as dropout / batch-norm

    input_data = load_data(input_file)
    data_loader = DataLoader(TensorDataset(input_data), batch_size=config.batch_size)

    all_embeddings = []
    # No gradients are needed for inference; this also saves memory.
    with torch.no_grad():
        # A TensorDataset built from one tensor yields 1-element batches.
        for (inputs,) in data_loader:
            embeddings = model.get_encoder_output(inputs.to(config.device))
            # Move to host memory so the batches can be joined with NumPy.
            all_embeddings.append(embeddings.cpu().numpy())

    all_embeddings = np.concatenate(all_embeddings, axis=0)
    print(f"Generated embeddings shape: {all_embeddings.shape}")

    # Save embeddings
    np.save(output_file, all_embeddings)
    print(f"Embeddings saved as {output_file}")

if __name__ == "__main__":
    # Command-line entry point: embed_script.py INPUT_CSV OUTPUT_NPY
    import argparse

    cli = argparse.ArgumentParser(
        description="Generate embeddings for microbial growth curves"
    )
    # Both arguments are required positionals, registered in this order.
    positional_specs = (
        ("input_file", "Path to the input CSV file containing growth curves"),
        ("output_file", "Path to save the output embeddings (as .npy file)"),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, help=arg_help)
    parsed = cli.parse_args()

    get_embeddings(parsed.input_file, parsed.output_file)