#!/usr/bin/env python3 """ Quick validation for the specific errors from the previous log """ def test_device_mesh_issue(): """Test the exact error: No module named 'torch.distributed.device_mesh'""" print("๐Ÿ” Testing device_mesh issue...") try: # This was the failing import chain from accelerate.parallelism_config import ParallelismConfig print("โœ… accelerate.parallelism_config: OK (device_mesh not required)") return True except ImportError as e: if "device_mesh" in str(e): print(f"โŒ device_mesh still required: {e}") return False else: print(f"โš ๏ธ Other import issue: {e}") return True def test_transformers_generation(): """Test transformers.generation.utils import""" print("๐Ÿ” Testing transformers generation utils...") try: from transformers.generation import GenerationConfig, GenerationMixin print("โœ… transformers.generation: OK") return True except ImportError as e: print(f"โŒ transformers.generation failed: {e}") return False def test_mistral_model_import(): """Test the specific mistral model import that failed""" print("๐Ÿ” Testing mistral model import...") try: from transformers.models.mistral.modeling_mistral import MistralForCausalLM print("โœ… MistralForCausalLM: OK") return True except ImportError as e: if "device_mesh" in str(e): print(f"โŒ Mistral still needs device_mesh: {e}") return False else: print(f"โš ๏ธ Mistral other issue: {e}") return True def test_tokenizer_enum_issue(): """Test the tokenizer enum issue""" print("๐Ÿ” Testing tokenizer enum compatibility...") try: from transformers import AutoTokenizer # Try to create a tokenizer that had enum issues tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium") print("โœ… DialoGPT tokenizer: No enum issues") return True except Exception as e: if "enum" in str(e).lower() or "variant" in str(e).lower(): print(f"โŒ Tokenizer enum issue persists: {e}") return False else: print(f"โš ๏ธ Tokenizer other issue: {e}") return True def main(): print("๐Ÿšจ Validation: Previous Error Conditions") print("=" * 50) tests = [ ("Device Mesh Issue", test_device_mesh_issue), ("Transformers Generation", test_transformers_generation), ("Mistral Model Import", test_mistral_model_import), ("Tokenizer Enum Issue", test_tokenizer_enum_issue) ] results = [] for name, test_func in tests: print(f"\n๐Ÿงช {name}:") try: result = test_func() results.append(result) except Exception as e: print(f"โŒ Test crashed: {e}") results.append(False) print("\n" + "=" * 50) passed = sum(results) total = len(results) if passed == total: print("โœ… ALL TESTS PASSED - Previous errors should be resolved!") else: print(f"โš ๏ธ {passed}/{total} tests passed - Some issues may persist") print(f"Success rate: {passed}/{total} ({100*passed/total:.1f}%)") if __name__ == "__main__": main()