"""Smoke test: download the base model, load it, and run a short greedy generation."""

import logging
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_model_loading(base_model='mistralai/Mistral-7B-Instruct-v0.3', timeout=600):
    """Check that the base model's tokenizer and weights load and can generate text.

    Returns True on success, False on any failure. ``timeout`` is a soft budget
    in seconds: the elapsed time is logged and a warning is emitted if the
    budget is exceeded.
    """
    try:
        start_time = time.time()

        # Download the tokenizer, or reuse a previously cached copy.
        logger.info(f"Attempting to load tokenizer from {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            cache_dir='./model_cache',
            token=False,  # renamed from the deprecated use_auth_token in recent transformers releases
            local_files_only=False,
            resume_download=True,
        )

        # Decoder-only models typically ship without a pad token; reuse EOS so
        # padded batches tokenize without errors.
        logger.info("Tokenizer loaded. Setting pad token if needed.")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Loading model with extended timeout handling") |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
base_model, |
|
|
cache_dir='./model_cache', |
|
|
use_auth_token=False, |
|
|
local_files_only=False, |
|
|
resume_download=True, |
|
|
|
|
|
force_download=False, |
|
|
|
|
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 |
|
|
) |
|
|
|
|
|
|
|
|
input_text = "Hello, how are you today?" |
|
|
inputs = tokenizer( |
|
|
input_text, |
|
|
return_tensors="pt", |
|
|
padding=True, |
|
|
add_special_tokens=True, |
|
|
return_attention_mask=True |
|
|
) |
|
|
|
|
|

        # Greedy decoding of a handful of tokens is enough to confirm the model works.
        with torch.no_grad():
            output = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=20,
                num_return_sequences=1,
                do_sample=False,
            )

        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        logger.info(f"Generation successful: {generated_text}")

        # Report the round-trip time against the soft timeout budget.
        elapsed = time.time() - start_time
        logger.info(f"Completed in {elapsed:.1f}s (budget: {timeout}s)")
        if elapsed > timeout:
            logger.warning("Model loading exceeded the configured timeout budget")

        return True

    except Exception as e:
        logger.error(f"Detailed error during model loading: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == '__main__':
    test_model_loading()
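
# Usage sketch (hypothetical invocation; the module/file name is an assumption):
#   python test_model_loading.py
# or, from another script, point the check at a smaller model such as distilgpt2
# to iterate faster:
#   from test_model_loading import test_model_loading
#   ok = test_model_loading(base_model='distilgpt2', timeout=120)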