# from transformers import Trainer
# import wandb
# import note_seq

# from utils import token_sequence_to_note_sequence

# # Custom Trainer that generates a sample after each evaluation loop and logs the
# # rendered audio to Weights & Biases.
# SAMPLE_RATE = 44100


# class CustomTrainer(Trainer):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)

#     def evaluation_loop(
#         self,
#         dataloader,
#         description,
#         prediction_loss_only=None,
#         ignore_keys=None,
#         metric_key_prefix="eval",
#     ):
#         # call super class method to get the eval outputs
#         eval_output = super().evaluation_loop(
#             dataloader,
#             description,
#             prediction_loss_only,
#             ignore_keys,
#             metric_key_prefix,
#         )

#         # Generate a sample piece and log the rendered audio to Weights & Biases.
#         if wandb.run is not None:
#             input_ids = self.tokenizer.encode(
#                 "PIECE_START",
#                 return_tensors="pt",
#             ).cuda()
#             # Generate the four voices sequentially, each conditioned on the previous output.
#             voice1_generated_ids = self.model.generate(
#                 input_ids,
#                 max_new_tokens=512,
#                 do_sample=True,
#                 temperature=0.75,
#                 eos_token_id=self.tokenizer.encode("TRACK_END")[0],
#             )
#             voice2_generated_ids = self.model.generate(
#                 voice1_generated_ids,
#                 max_new_tokens=512,
#                 do_sample=True,
#                 temperature=0.75,
#                 eos_token_id=self.tokenizer.encode("TRACK_END")[0],
#             )
#             voice3_generated_ids = self.model.generate(
#                 voice2_generated_ids,
#                 max_new_tokens=512,
#                 do_sample=True,
#                 temperature=0.75,
#                 eos_token_id=self.tokenizer.encode("TRACK_END")[0],
#             )
#             voice4_generated_ids = self.model.generate(
#                 voice3_generated_ids,
#                 max_new_tokens=512,
#                 do_sample=True,
#                 temperature=0.75,
#                 eos_token_id=self.tokenizer.encode("TRACK_END")[0],
#             )
#             token_sequence = self.tokenizer.decode(voice4_generated_ids[0])
#             note_sequence = token_sequence_to_note_sequence(token_sequence)
#             synth = note_seq.fluidsynth
#             array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
#             int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
#             wandb.log({"Generated_audio": wandb.Audio(int16_data, SAMPLE_RATE)})

#         return eval_output
import torch
from typing import Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initialize the model and tokenizer variables as None
tokenizer = None
model = None


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Returns the preloaded model and tokenizer. If they haven't been loaded before, loads them.
    Returns:
        tuple: A tuple containing the preloaded model and tokenizer.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
        model = AutoModelForCausalLM.from_pretrained(
            "juancopi81/lmd-8bars-2048-epochs40_v4"
        )

        # Move model to device
        model = model.to(device)

    return model, tokenizer
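

# A minimal usage sketch (illustrative only, not part of the module's API). The
# "PIECE_START" prompt, "TRACK_END" stop token, and sampling settings mirror the
# commented-out CustomTrainer above; they are assumptions for demonstration, not
# requirements of get_model_and_tokenizer.
if __name__ == "__main__":
    demo_model, demo_tokenizer = get_model_and_tokenizer()

    # Encode the start-of-piece token and move it to the model's device.
    demo_input_ids = demo_tokenizer.encode("PIECE_START", return_tensors="pt").to(
        demo_model.device
    )

    # Sample one track, stopping at the TRACK_END token.
    demo_generated_ids = demo_model.generate(
        demo_input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.75,
        eos_token_id=demo_tokenizer.encode("TRACK_END")[0],
    )

    print(demo_tokenizer.decode(demo_generated_ids[0]))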