Spaces:
Runtime error
Runtime error
File size: 872 Bytes
e43be0f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class ModelTrainer:
    """Hold a pre-trained causal language model together with its tokenizer.

    The model and tokenizer are loaded from ``model_name`` when the instance
    is created. No training loop is implemented in this class; it only
    prepares the objects a training routine would need.

    Attributes:
        model_name: Identifier passed to ``from_pretrained`` (a Hub repo id
            or local path).
        device: Device the model was moved to, so callers can place their
            input tensors on the same device.
        model: The loaded ``AutoModelForCausalLM`` instance.
        tokenizer: The matching ``AutoTokenizer`` instance.
    """

    def __init__(self, model_name, device=None):
        """Load ``model_name`` and its tokenizer onto ``device``.

        Args:
            model_name: Hub repo id or local path of the checkpoint.
            device: Target device (e.g. ``"cpu"``, ``"cuda:0"``). When
                ``None`` (the default), CUDA is used if available, otherwise
                the CPU -- the same choice the original code made.
        """
        self.model_name = model_name
        # Store the resolved device so callers can move inputs next to the model.
        self.device = self._resolve_device(device)
        self.model, self.tokenizer = self.load_model()

    @staticmethod
    def _resolve_device(device=None):
        """Return ``device`` if given, else ``"cuda"`` when available or ``"cpu"``."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the pre-trained model and tokenizer for ``self.model_name``.

        The model is moved to ``self.device``. If that attribute has not been
        set (e.g. the instance bypassed ``__init__``), the default device is
        auto-detected.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        print(f"Loading model: {self.model_name}")
        device = getattr(self, "device", None)
        if device is None:
            device = self._resolve_device()
        model = AutoModelForCausalLM.from_pretrained(self.model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return model, tokenizer
if __name__ == "__main__":
    # Script entry point: load the checkpoint directly as a smoke test that
    # the model and tokenizer can be fetched and instantiated.
    trainer = ModelTrainer("sainoforce/modelv2")
    print("Model and tokenizer loaded successfully.")