#!/usr/bin/env python3
"""
Test script to validate Mistral-7B-Instruct AWQ model response generation.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_model():
    print("Loading Mistral-7B-Instruct AWQ for testing...")

    # Try the AWQ model first; fall back to the regular model if it fails to load
    try:
        tokenizer = AutoTokenizer.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.2-AWQ")
        model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        print("āœ… AWQ model loaded successfully!")
    except Exception as e:
        print(f"āš ļø AWQ model failed to load: {e}")
        print("šŸ“¦ Falling back to regular Mistral-7B-Instruct-v0.2...")
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        print("āœ… Regular model loaded successfully!")

    # Mistral's tokenizer has no pad token by default; reuse EOS so generate() can pad
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Test prompts
    test_messages = [
        "I feel sad today",
        "What should I do?",
        "Hello",
    ]

    for i, message in enumerate(test_messages):
        print(f"\n--- Test {i+1}: '{message}' ---")

        # Use the Mistral chat template format
        messages = [
            {"role": "user", "content": message},
        ]

        # Render the conversation to a prompt string with the generation prompt appended
        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The chat template already inserts special tokens (e.g. BOS), so disable
        # add_special_tokens here to avoid a doubled BOS at the start of the prompt
        input_ids = tokenizer.encode(
            conversation,
            return_tensors="pt",
            add_special_tokens=False,
        )

        # Generate a response with sampling settings tuned for Mistral AWQ
        with torch.no_grad():
            output_ids = model.generate(
                input_ids.to(model.device),
                max_new_tokens=100,
                no_repeat_ngram_size=2,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                temperature=0.9,
                top_k=50,
                top_p=0.9,
                use_cache=True,
            )

        # Decode only the newly generated tokens (everything past the prompt)
        response = tokenizer.decode(
            output_ids[:, input_ids.shape[-1]:][0],
            skip_special_tokens=True,
        ).strip()

        print(f"Raw response: '{response}'")
        print(f"Response length: {len(response)} characters")

        if len(response) > 1:
            print("āœ… Good response generated")
        else:
            print("āš ļø Short/empty response")

    print("\nāœ… Model testing complete!")


if __name__ == "__main__":
    test_model()
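

# --- Optional streaming variant (a minimal sketch, not exercised by test_model) ---
# The helper below shows how the same generation call could stream tokens to the
# console as they are produced, using transformers' TextStreamer, which is handy
# when eyeballing whether the model is emitting reasonable text before the full
# decode finishes. It assumes `model` and `tokenizer` are the objects loaded in
# test_model(); the function name `stream_response` is illustrative, not an
# existing API. Example use after loading: stream_response(model, tokenizer, "Hello")
from transformers import TextStreamer


def stream_response(model, tokenizer, message: str, max_new_tokens: int = 100) -> None:
    """Print a response token-by-token instead of waiting for the full decode."""
    conversation = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=False,
        add_generation_prompt=True,
    )
    input_ids = tokenizer.encode(
        conversation,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the printed stream;
    # extra kwargs are forwarded to tokenizer.decode
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    with torch.no_grad():
        model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.9,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )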