File size: 872 Bytes
e43be0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class ModelTrainer:
    """Hold a pre-trained causal language model and its tokenizer.

    The model and tokenizer are loaded eagerly in ``__init__``; no training
    logic is implemented in this class yet.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        device: Device to place the model on (e.g. ``"cpu"``, ``"cuda"``,
            ``"cuda:1"``). ``None`` (the default) selects ``"cuda"`` when
            ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise.

    Attributes:
        model_name: The identifier the model was loaded from.
        device: The device the model was moved to.
        model: The loaded ``AutoModelForCausalLM`` instance.
        tokenizer: The loaded ``AutoTokenizer`` instance.
    """

    def __init__(self, model_name, device=None):
        self.model_name = model_name
        # Resolve and keep the device so later code (batches, a training
        # loop) can move tensors to the same device as the model.
        self.device = self._resolve_device(device)
        self.model, self.tokenizer = self.load_model()

    @staticmethod
    def _resolve_device(device=None):
        """Return *device* unchanged, or auto-detect CUDA vs CPU if it is None."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the pre-trained model and tokenizer for ``self.model_name``.

        Returns:
            A ``(model, tokenizer)`` tuple; the model has been moved to the
            resolved device.
        """
        print(f"Loading model: {self.model_name}")

        # getattr keeps this usable on instances that bypassed __init__
        # (e.g. subclasses that call load_model before setting self.device).
        device = self._resolve_device(getattr(self, "device", None))
        model = AutoModelForCausalLM.from_pretrained(self.model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return model, tokenizer

if __name__ == "__main__":
    # Function-scope import: only the CLI entry point needs argparse.
    import argparse

    parser = argparse.ArgumentParser(
        description="Load a causal language model and its tokenizer."
    )
    # Optional positional argument so running the script with no arguments
    # keeps the original behavior of loading "sainoforce/modelv2".
    parser.add_argument(
        "model_name",
        nargs="?",
        default="sainoforce/modelv2",
        help="Model id or local path (default: %(default)s)",
    )
    args = parser.parse_args()

    trainer = ModelTrainer(args.model_name)
    print("Model and tokenizer loaded successfully.")