lemms committed on
Commit
e60435f
·
verified ·
1 Parent(s): a024114

Add OpenLLM custom tokenizer test script

Browse files
Files changed (1) hide show
  1. openllm_tokenizer_fix.py +55 -0
openllm_tokenizer_fix.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OpenLLM Custom Tokenizer Fix Script
4
+
5
+ This script demonstrates the correct way to load OpenLLM models with their
6
+ custom tokenizer classes using trust_remote_code=True.
7
+
8
+ Author: Louis Chua Bean Chong
9
+ License: GPL-3.0
10
+ """
11
+
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+ import torch
14
+
15
def test_openllm_loading(model_name: str = "lemms/openllm-small-extended-7k") -> bool:
    """Smoke-test loading an OpenLLM checkpoint with its custom tokenizer.

    OpenLLM repositories ship their own tokenizer and model classes, so both
    ``AutoTokenizer.from_pretrained`` and ``AutoModelForCausalLM.from_pretrained``
    must be called with ``trust_remote_code=True``; without it, loading fails.

    Args:
        model_name: Hugging Face model ID to load. Defaults to the OpenLLM
            small extended-7k checkpoint this script was written to verify.

    Returns:
        True if both the tokenizer and the model load successfully,
        False otherwise (the exception is printed, not re-raised, so the
        script can run as a best-effort smoke test).
    """
    print("🔍 Testing OpenLLM Custom Tokenizer Loading")
    print("=" * 50)
    print(f"Model: {model_name}")
    print("Note: OpenLLM uses custom tokenizer classes")
    print()

    try:
        # trust_remote_code=True is required: the tokenizer class lives in
        # the model repository, not in the transformers library itself.
        print("🔄 Loading custom tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False,  # slow tokenizer for compatibility with the custom class
        )
        print(f"✅ Tokenizer loaded: {type(tokenizer).__name__}")

        # The model class is also repo-defined, so it needs the same flag.
        print("🔄 Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # half precision: smaller download-to-RAM footprint
            trust_remote_code=True,
        )
        print(f"✅ Model loaded: {type(model).__name__}")

        print("\n🎉 OpenLLM loading successful!")
        print("The key is using trust_remote_code=True for custom classes")
        return True

    except Exception as e:
        # Top-level boundary of a best-effort smoke test: report and signal
        # failure via the return value instead of crashing.
        print(f"❌ Loading failed: {e}")
        return False
53
+
54
+ if __name__ == "__main__":
55
+ test_openllm_loading()