linzhao-amd commited on
Commit
8907e49
·
verified ·
1 Parent(s): f728acc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -1
README.md CHANGED
@@ -32,6 +32,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
32
  from datasets import load_dataset
33
  from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors
34
  from quark.contrib.llm_eval import ppl_eval
 
35
  # Register qwen3_next template
36
  qwen3_next_template = LLMTemplate(
37
  model_type="qwen3_next",
@@ -40,28 +41,34 @@ qwen3_next_template = LLMTemplate(
40
  exclude_layers_name=["lm_head", "*linear_attn.in_proj_ba", "*linear_attn.in_proj_qkvz","*mlp.gate", "*mlp.shared_expert_gate", "*self_attn.k_proj", "*self_attn.q_proj", "*self_attn.v_proj"],
41
  )
42
  LLMTemplate.register_template(qwen3_next_template)
 
43
  # Configuration
44
  ckpt_path = "Qwen/Qwen3-Coder-Next"
45
  output_dir = "amd/Qwen3-Coder-Next-MXFP4"
46
  quant_scheme = "mxfp4"
47
  exclude_layers = ["lm_head", "*linear_attn.in_proj_ba", "*linear_attn.in_proj_qkvz","*mlp.gate", "*mlp.shared_expert_gate", "*self_attn.k_proj", "*self_attn.q_proj", "*self_attn.v_proj"]
 
48
  # Load model
49
  model = AutoModelForCausalLM.from_pretrained(ckpt_path, torch_dtype="auto", device_map="auto")
50
  model.eval()
51
  tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)
52
  processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
 
53
  # Get quant config from template
54
  template = LLMTemplate.get(model.config.model_type)
55
  quant_config = template.get_config(scheme=quant_scheme, exclude_layers=exclude_layers)
 
56
  # Quantize
57
  quantizer = ModelQuantizer(quant_config)
58
  model = quantizer.quantize_model(model)
59
  model = quantizer.freeze(model)
 
60
  # Export hf_format
61
  export_safetensors(model, output_dir, custom_mode="quark")
62
  tokenizer.save_pretrained(output_dir)
63
  processor.save_pretrained(output_dir)
64
- # Evaluate PPL
 
65
  testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
66
  testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
67
  ppl = ppl_eval(model, testenc, model.device)
 
32
  from datasets import load_dataset
33
  from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors
34
  from quark.contrib.llm_eval import ppl_eval
35
+
36
  # Register qwen3_next template
37
  qwen3_next_template = LLMTemplate(
38
  model_type="qwen3_next",
 
41
  exclude_layers_name=["lm_head", "*linear_attn.in_proj_ba", "*linear_attn.in_proj_qkvz","*mlp.gate", "*mlp.shared_expert_gate", "*self_attn.k_proj", "*self_attn.q_proj", "*self_attn.v_proj"],
42
  )
43
  LLMTemplate.register_template(qwen3_next_template)
44
+
45
  # Configuration
46
  ckpt_path = "Qwen/Qwen3-Coder-Next"
47
  output_dir = "amd/Qwen3-Coder-Next-MXFP4"
48
  quant_scheme = "mxfp4"
49
  exclude_layers = ["lm_head", "*linear_attn.in_proj_ba", "*linear_attn.in_proj_qkvz","*mlp.gate", "*mlp.shared_expert_gate", "*self_attn.k_proj", "*self_attn.q_proj", "*self_attn.v_proj"]
50
+
51
  # Load model
52
  model = AutoModelForCausalLM.from_pretrained(ckpt_path, torch_dtype="auto", device_map="auto")
53
  model.eval()
54
  tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)
55
  processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
56
+
57
  # Get quant config from template
58
  template = LLMTemplate.get(model.config.model_type)
59
  quant_config = template.get_config(scheme=quant_scheme, exclude_layers=exclude_layers)
60
+
61
  # Quantize
62
  quantizer = ModelQuantizer(quant_config)
63
  model = quantizer.quantize_model(model)
64
  model = quantizer.freeze(model)
65
+
66
  # Export hf_format
67
  export_safetensors(model, output_dir, custom_mode="quark")
68
  tokenizer.save_pretrained(output_dir)
69
  processor.save_pretrained(output_dir)
70
+
71
+ # Evaluate PPL (optional)
72
  testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
73
  testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
74
  ppl = ppl_eval(model, testenc, model.device)