"""
Test script to validate Mistral-7B-Instruct AWQ model response generation
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_model():
    print("Loading Mistral-7B-Instruct AWQ for testing...")

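    # Try the 4-bit AWQ build first; loading it through transformers
    # typically requires the autoawq package and a CUDA-capable GPU.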
    try:
        tokenizer = AutoTokenizer.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-AWQ")
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        print("✅ AWQ model loaded successfully!")
    except Exception as e:
        print(f"⚠️ AWQ model failed to load: {e}")
        print("📦 Falling back to regular Mistral-7B-Instruct-v0.2...")
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        print("✅ Regular model loaded successfully!")

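    # Mistral's tokenizer ships without a pad token; reuse EOS so
    # generate() does not warn about a missing pad_token_id.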
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

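    # A few short prompts to sanity-check end-to-end generation.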
    test_messages = [
        "I feel sad today",
        "What should I do?",
        "Hello",
    ]

    for i, message in enumerate(test_messages):
        print(f"\n--- Test {i+1}: '{message}' ---")

        messages = [
            {"role": "user", "content": message},
        ]

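        # Render the prompt with the model's chat template (Mistral's
        # [INST] ... [/INST] format), ending with the generation prompt.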
        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The templated string already starts with the BOS token, so skip
        # add_special_tokens to avoid prepending a second <s>.
        input_ids = tokenizer.encode(
            conversation, return_tensors="pt", add_special_tokens=False
        )

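        # Sample a reply: temperature/top_k/top_p add variety, while
        # no_repeat_ngram_size=2 curbs verbatim repetition.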
| |
| with torch.no_grad(): |
| chat_history_ids = model.generate( |
| input_ids.to(model.device), |
| max_new_tokens=100, |
| no_repeat_ngram_size=2, |
| do_sample=True, |
| pad_token_id=tokenizer.pad_token_id, |
| eos_token_id=tokenizer.eos_token_id, |
| temperature=0.9, |
| top_k=50, |
| top_p=0.9, |
| use_cache=True |
| ) |
| |
| |
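        # generate() returns prompt + completion tokens; slice off the
        # prompt so only the newly generated reply is decoded.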
        response = tokenizer.decode(
            chat_history_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        ).strip()

| print(f"Raw response: '{response}'") |
| print(f"Response length: {len(response)} characters") |
| |
| if len(response) > 1: |
| print("✅ Good response generated") |
| else: |
| print("⚠️ Short/empty response") |
| |
| print("\n✅ Model testing complete!") |
|

if __name__ == "__main__":
    test_model()