# chatbot/validate_fix.py
# Author: Deva1211 — "Resolving issues" (commit fd5eb19)
#!/usr/bin/env python3
"""
Quick validation for the specific errors from the previous log
"""
def test_device_mesh_issue():
    """Test the exact error: No module named 'torch.distributed.device_mesh'"""
    print("πŸ” Testing device_mesh issue...")
    try:
        # This was the failing import chain
        from accelerate.parallelism_config import ParallelismConfig
    except ImportError as e:
        message = str(e)
        if "device_mesh" in message:
            print(f"❌ device_mesh still required: {e}")
            return False
        # Any other ImportError means device_mesh itself is no longer needed.
        print(f"⚠️ Other import issue: {e}")
        return True
    print("βœ… accelerate.parallelism_config: OK (device_mesh not required)")
    return True
def test_transformers_generation():
"""Test transformers.generation.utils import"""
print("πŸ” Testing transformers generation utils...")
try:
from transformers.generation import GenerationConfig, GenerationMixin
print("βœ… transformers.generation: OK")
return True
except ImportError as e:
print(f"❌ transformers.generation failed: {e}")
return False
def test_mistral_model_import():
    """Test the specific mistral model import that failed"""
    print("πŸ” Testing mistral model import...")
    try:
        from transformers.models.mistral.modeling_mistral import MistralForCausalLM
    except ImportError as e:
        reason = str(e)
        if "device_mesh" in reason:
            print(f"❌ Mistral still needs device_mesh: {e}")
            return False
        # Unrelated import failures are not the device_mesh regression.
        print(f"⚠️ Mistral other issue: {e}")
        return True
    print("βœ… MistralForCausalLM: OK")
    return True
def test_tokenizer_enum_issue():
    """Test the tokenizer enum issue"""
    print("πŸ” Testing tokenizer enum compatibility...")
    try:
        from transformers import AutoTokenizer
        # Try to create a tokenizer that had enum issues
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    except Exception as e:
        lowered = str(e).lower()
        if "enum" in lowered or "variant" in lowered:
            print(f"❌ Tokenizer enum issue persists: {e}")
            return False
        # Failures unrelated to the enum/variant bug are treated as benign here.
        print(f"⚠️ Tokenizer other issue: {e}")
        return True
    print("βœ… DialoGPT tokenizer: No enum issues")
    return True
def main():
    """Run every validation check and print a pass/fail summary.

    Each check is run inside a try/except so one crashing check cannot
    abort the whole validation run; a crash is recorded as a failure.

    Returns:
        bool: True when every check passed, False otherwise. (New,
        backward-compatible return value — previously ``None`` — so the
        ``__main__`` guard can set a meaningful process exit code.)
    """
    print("🚨 Validation: Previous Error Conditions")
    print("=" * 50)
    tests = [
        ("Device Mesh Issue", test_device_mesh_issue),
        ("Transformers Generation", test_transformers_generation),
        ("Mistral Model Import", test_mistral_model_import),
        ("Tokenizer Enum Issue", test_tokenizer_enum_issue),
    ]
    results = []
    for name, test_func in tests:
        print(f"\nπŸ§ͺ {name}:")
        try:
            results.append(test_func())
        except Exception as e:
            # A crashing check is a failed check, not a fatal error.
            print(f"❌ Test crashed: {e}")
            results.append(False)
    print("\n" + "=" * 50)
    passed = sum(results)
    total = len(results)
    if passed == total:
        print("βœ… ALL TESTS PASSED - Previous errors should be resolved!")
    else:
        print(f"⚠️ {passed}/{total} tests passed - Some issues may persist")
    # Guard against ZeroDivisionError if the test list is ever emptied.
    rate = 100 * passed / total if total else 0.0
    print(f"Success rate: {passed}/{total} ({rate:.1f}%)")
    return passed == total


if __name__ == "__main__":
    import sys

    # Exit non-zero on failure so CI / shell callers can detect a bad run
    # (previously the script always exited 0, even when checks failed).
    sys.exit(0 if main() else 1)