Mistral_Test / basic_inference.py
eesfeg's picture
Add application file
1e639fb
raw
history blame contribute delete
633 Bytes
# basic_inference.py
"""Load a causal LM + tokenizer and place the model on the right device.

Loads from a local directory (or a Hugging Face hub id) and, on CUDA,
lets accelerate's ``device_map="auto"`` handle weight placement with
fp16 to halve memory; on CPU it loads fp32 weights normally.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load from local directory or Hugging Face
model_path = "./tinyllama"  # or "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# device_map="auto" is only used on CUDA; remember that choice so the
# manual .to() below can be skipped when accelerate already placed weights.
use_device_map = device == "cuda"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if use_device_map else None,
)

# Move to device ONLY if device_map was not used. The original condition was
# inverted (`if device == "cuda"`): calling .to() on a model dispatched with
# device_map="auto" is redundant and can break offloaded (meta/disk) weights.
if not use_device_map and device != "cpu":
    model = model.to(device)