"""
Test script to validate Mistral-7B-Instruct AWQ model response generation.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
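
# NOTE (assumption): loading an AWQ checkpoint through transformers
# typically requires the `autoawq` package and a CUDA GPU; the
# try/except in test_model() falls back to unquantized weights when
# AWQ support is unavailable.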


def test_model():
    print("Loading Mistral-7B-Instruct AWQ for testing...")

    # Prefer the 4-bit AWQ build, which needs far less GPU memory than fp16.
    try:
        tokenizer = AutoTokenizer.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-AWQ")
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        print("✅ AWQ model loaded successfully!")
    except Exception as e:
        print(f"⚠️ AWQ model failed to load: {e}")
        print("📦 Falling back to regular Mistral-7B-Instruct-v0.2...")
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        print("✅ Regular model loaded successfully!")

    # Mistral's tokenizer ships without a pad token; reuse EOS so generate()
    # has a valid pad_token_id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    test_messages = [
        "I feel sad today",
        "What should I do?",
        "Hello"
    ]

    for i, message in enumerate(test_messages):
        print(f"\n--- Test {i+1}: '{message}' ---")

        messages = [
            {"role": "user", "content": message}
        ]

        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
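        # For the Mistral-Instruct template this renders to roughly
        # "<s>[INST] I feel sad today [/INST]". The BOS token is already
        # included, so tokenization below must not add special tokens again.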

        input_ids = tokenizer.encode(
            conversation, return_tensors="pt", add_special_tokens=False
        )

        with torch.no_grad():
            chat_history_ids = model.generate(
                input_ids.to(model.device),
                # All-ones mask: a single-sequence batch has no padding.
                attention_mask=torch.ones_like(input_ids).to(model.device),
                max_new_tokens=100,
                no_repeat_ngram_size=2,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                temperature=0.9,
                top_k=50,
                top_p=0.9,
                use_cache=True
            )
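
        # generate() returns prompt + completion in one tensor for
        # decoder-only models, so slice off the prompt tokens before decoding.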
        response = tokenizer.decode(
            chat_history_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()

        print(f"Raw response: '{response}'")
        print(f"Response length: {len(response)} characters")

        if len(response) > 1:
            print("✅ Good response generated")
        else:
            print("⚠️ Short/empty response")

    print("\n✅ Model testing complete!")


if __name__ == "__main__":
    test_model()
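
# Outputs vary run to run because do_sample=True makes decoding stochastic;
# the length check in test_model() only verifies a non-empty reply came back.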