# NOTE(review): the lines below were Hugging Face Spaces page chrome
# ("Spaces: Sleeping") captured by extraction, not part of the program.
| """ | |
| Zen Coder Flash - Identity Training Space | |
| Fine-tune GLM-4.7-Flash with Zen identity using LoRA | |
| """ | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments | |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training | |
| from datasets import load_dataset | |
| import os | |
# Constants
# Hub ID of the base model that gets fine-tuned.
MODEL_ID = "zenlm/zen-coder-flash"
# Hub ID of the full identity dataset.
# NOTE(review): unused in the visible code — training uses the embedded
# IDENTITY_DATA list below instead; confirm which source is intended.
DATASET_ID = "zenlm/zen-coder-flash-identity"
# Local directory where trained LoRA adapters are saved and loaded from.
OUTPUT_DIR = "./zen-coder-flash-lora"

# Identity data (embedded for quick training)
# Instruction/output pairs teaching the model its "Zen Coder Flash" identity;
# create_training_data() renders each pair into a chat-formatted string.
IDENTITY_DATA = [
    {"instruction": "Who are you?", "output": "I am Zen Coder Flash, a 31B MoE AI model from the Zen family. I'm built on the GLM-4.7-Flash architecture and optimized for code generation, software development, and agentic reasoning. With 31 billion parameters (only 3B active per token) and 131K context, I represent the flagship balance of capability and efficiency in the Zen coder lineup."},
    {"instruction": "What is your name?", "output": "My name is Zen Coder Flash. I'm the flagship coder model in the Zen family, based on GLM-4.7-Flash's advanced Mixture of Experts architecture with 31 billion parameters."},
    {"instruction": "Are you ChatGPT?", "output": "No, I'm not ChatGPT. I'm Zen Coder Flash, a distinct model from the Zen AI family. While we both assist with coding and general tasks, I'm based on GLM-4.7-Flash's MoE architecture and specifically optimized for code generation, software development, and agentic reasoning."},
    {"instruction": "Are you Claude?", "output": "No, I'm Zen Coder Flash, not Claude. I'm part of the Zen model family, built on GLM-4.7-Flash. My focus is code generation and software development, with MoE architecture enabling efficient frontier performance."},
    {"instruction": "What can you do?", "output": "As Zen Coder Flash, I excel at: code generation across 100+ languages, debugging and code review, software architecture, tool calling and function execution, extended reasoning for complex problems, and technical documentation. My 131K context window handles large codebases."},
    {"instruction": "What's your SWE-bench score?", "output": "I scored 59.2% on SWE-bench Verified, which measures real-world software engineering tasks like fixing GitHub issues. For comparison, Qwen3-30B scores 22% and GPT-OSS-20B scores 34%. This reflects strong practical coding ability."},
]
def create_training_data():
    """Build a chat-formatted HF Dataset from the embedded identity pairs.

    Returns:
        datasets.Dataset with a single "text" column, one row per pair in
        IDENTITY_DATA, rendered in the <|user|>/<|assistant|> chat template.
    """
    from datasets import Dataset

    rows = [
        {"text": f"<|user|>\n{pair['instruction']}<|assistant|>\n{pair['output']}<|endoftext|>"}
        for pair in IDENTITY_DATA
    ]
    return Dataset.from_list(rows)
def train_model(
    learning_rate: float = 1e-4,
    num_epochs: int = 3,
    batch_size: int = 1,
    lora_r: int = 8,
    lora_alpha: int = 16,
    progress=gr.Progress()
):
    """Fine-tune the base model on the embedded identity data with QLoRA.

    Args:
        learning_rate: Optimizer learning rate.
        num_epochs: Number of passes over the identity dataset.
        batch_size: Per-device training batch size.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        A human-readable status string (success or "no GPU" warning).
    """
    progress(0, desc="Loading model...")

    # Training a 31B-parameter model (even 4-bit quantized) requires a GPU.
    if not torch.cuda.is_available():
        return "⚠️ No GPU detected. Training requires GPU. Please upgrade to a GPU Space."

    # Load the base model in 4-bit NF4 so it fits in GPU memory (QLoRA).
    from transformers import BitsAndBytesConfig
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Many causal-LM tokenizers ship without a pad token; padding="max_length"
    # below would fail without one, so fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    progress(0.2, desc="Preparing LoRA...")
    # Enable gradient checkpointing / cast norms for stable k-bit training.
    model = prepare_model_for_kbit_training(model)
    lora_config = LoraConfig(
        r=int(lora_r),            # Gradio sliders may deliver floats
        lora_alpha=int(lora_alpha),
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    progress(0.3, desc="Loading dataset...")
    dataset = create_training_data()

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=512,
            padding="max_length",
        )

    # Drop the raw "text" column so only model inputs reach the collator.
    tokenized_dataset = dataset.map(
        tokenize_function, batched=True, remove_columns=["text"]
    )

    progress(0.4, desc="Starting training...")
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=int(batch_size),
        learning_rate=learning_rate,
        logging_steps=1,
        save_steps=50,
        # The 4-bit compute dtype above is bfloat16; use matching bf16 mixed
        # precision (the original fp16=True conflicted with it).
        bf16=True,
        report_to="none",
    )

    from transformers import Trainer, DataCollatorForLanguageModeling
    # mlm=False -> labels are the input ids shifted (causal-LM objective).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    trainer.train()

    progress(0.9, desc="Saving adapters...")
    # Saves only the LoRA adapter weights, not the full base model.
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    progress(1.0, desc="Done!")
    return f"✅ Training complete! Adapters saved to {OUTPUT_DIR}"
def test_model(prompt: str):
    """Generate a response from the base model plus trained LoRA adapters.

    Args:
        prompt: User prompt to send to the model.

    Returns:
        The decoded model response, or a warning string when no trained
        adapters exist in OUTPUT_DIR.
    """
    if not os.path.exists(OUTPUT_DIR):
        return "⚠️ No trained model found. Please train first."

    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Load the full-precision-ish base model, then attach the saved adapters.
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)

    # Same chat template used at training time (see create_training_data).
    formatted = f"<|user|>\n{prompt}<|assistant|>\n"
    inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # Decode only the newly generated tokens. The original decoded the whole
    # sequence and split on "<|assistant|>", but skip_special_tokens=True can
    # strip that marker, which would return the prompt as part of the answer.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
def push_to_hub(repo_id: str):
    """Upload the trained LoRA adapter folder to the Hugging Face Hub.

    Args:
        repo_id: Target repository id, e.g. "zenlm/zen-coder-flash-lora".

    Returns:
        Status string containing the repo URL, or a warning when nothing
        has been trained yet.
    """
    if not os.path.exists(OUTPUT_DIR):
        return "⚠️ No trained model found. Please train first."

    from huggingface_hub import HfApi

    api = HfApi()
    # upload_folder fails when the target repo does not exist yet; create it
    # idempotently first (no-op if it already exists).
    api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=repo_id,
        repo_type="model",
    )
    return f"✅ Pushed to https://huggingface.co/{repo_id}"
# Gradio UI
with gr.Blocks(title="Zen Coder Flash Trainer") as demo:
    gr.Markdown("""
    # ⚡ Zen Coder Flash - Identity Training
    Fine-tune GLM-4.7-Flash with Zen identity using LoRA.
    **Model:** [zenlm/zen-coder-flash](https://huggingface.co/zenlm/zen-coder-flash)
    """)

    with gr.Tab("🎯 Train"):
        gr.Markdown("### Training Parameters")
        with gr.Row():
            lr = gr.Slider(1e-5, 1e-3, value=1e-4, label="Learning Rate")
            epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
        with gr.Row():
            batch = gr.Slider(1, 4, value=1, step=1, label="Batch Size")
            lora_r = gr.Slider(4, 64, value=8, step=4, label="LoRA Rank")
        # Expose lora_alpha: train_model accepts it, but the original UI
        # never passed it, silently pinning it to its default.
        lora_alpha = gr.Slider(8, 128, value=16, step=8, label="LoRA Alpha")
        train_btn = gr.Button("🚀 Start Training", variant="primary")
        train_output = gr.Textbox(label="Status", lines=3)
        train_btn.click(
            train_model,
            inputs=[lr, epochs, batch, lora_r, lora_alpha],
            outputs=train_output,
        )

    with gr.Tab("🧪 Test"):
        gr.Markdown("### Test Trained Model")
        test_input = gr.Textbox(
            label="Prompt",
            placeholder="Who are you?",
            lines=2,
        )
        test_btn = gr.Button("Generate")
        test_output = gr.Textbox(label="Response", lines=5)
        test_btn.click(test_model, inputs=test_input, outputs=test_output)

    with gr.Tab("📤 Push"):
        gr.Markdown("### Push to HuggingFace")
        repo_input = gr.Textbox(
            label="Repository ID",
            value="zenlm/zen-coder-flash-lora",
        )
        push_btn = gr.Button("Push to Hub")
        push_output = gr.Textbox(label="Status")
        push_btn.click(push_to_hub, inputs=repo_input, outputs=push_output)

if __name__ == "__main__":
    demo.launch()