| """ | |
| Zen Training Space - Unified Training for All Zen Models | |
| Train any Zen model with any dataset combination from HuggingFace | |
| """ | |
| import os | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer, AutoProcessor, TrainingArguments, Trainer | |
| from datasets import load_dataset, concatenate_datasets | |
| import json | |
| from typing import List, Dict | |
# Model configurations
MODELS = {
    "Language Models": {
        "zen-nano-0.6b": {
            "hf_id": "zenlm/zen-nano-0.6b",
            "type": "language",
            "size": "0.6B",
            "context": "32K"
        },
        "zen-eco-4b-instruct": {
            "hf_id": "zenlm/zen-eco-4b-instruct",
            "type": "language",
            "size": "4B",
            "context": "32K"
        },
        "zen-eco-4b-agent": {
            "hf_id": "zenlm/zen-eco-4b-agent",
            "type": "language",
            "size": "4B",
            "context": "32K"
        },
        "zen-omni-7b": {
            "hf_id": "zenlm/zen-omni-7b",
            "type": "language",
            "size": "7B",
            "context": "32K"
        },
        "zen-coder-14b": {
            "hf_id": "zenlm/zen-coder-14b",
            "type": "language",
            "size": "14B",
            "context": "128K"
        },
        "zen-next-32b": {
            "hf_id": "zenlm/zen-next-32b",
            "type": "language",
            "size": "32B",
            "context": "32K"
        },
    },
    "Vision-Language Models": {
        "zen-vl-4b-instruct": {
            "hf_id": "zenlm/zen-vl-4b-instruct",
            "type": "vision-language",
            "size": "4B",
            "context": "32K"
        },
        "zen-vl-8b-instruct": {
            "hf_id": "zenlm/zen-vl-8b-instruct",
            "type": "vision-language",
            "size": "8B",
            "context": "32K"
        },
        "zen-vl-30b-instruct": {
            "hf_id": "zenlm/zen-vl-30b-instruct",
            "type": "vision-language",
            "size": "30B",
            "context": "32K"
        },
    }
}
# Dataset configurations
DATASETS = {
    "Agent Training": {
        "ADP - AgentTuning OS": {
            "hf_id": "neulab/agent-data-collection",
            "config": "agenttuning_os",
            "size": "~5k samples"
        },
        "ADP - AgentTuning KG": {
            "hf_id": "neulab/agent-data-collection",
            "config": "agenttuning_kg",
            "size": "~5k samples"
        },
        "ADP - AgentTuning DB": {
            "hf_id": "neulab/agent-data-collection",
            "config": "agenttuning_db",
            "size": "~5k samples"
        },
        "ADP - Synatra": {
            "hf_id": "neulab/agent-data-collection",
            "config": "synatra",
            "size": "99k samples"
        },
        "ADP - Code Feedback": {
            "hf_id": "neulab/agent-data-collection",
            "config": "code_feedback",
            "size": "66k samples"
        },
        "ADP - Go Browse": {
            "hf_id": "neulab/agent-data-collection",
            "config": "go-browse-wa",
            "size": "27k samples"
        },
    },
    "Function Calling": {
        "xLAM Function Calling 60k": {
            "hf_id": "Salesforce/xlam-function-calling-60k",
            "config": None,
            "size": "60k samples"
        },
    },
    "Coding Datasets": {
        "Magicoder-OSS-Instruct": {
            "hf_id": "ise-uiuc/Magicoder-OSS-Instruct-75K",
            "config": None,
            "size": "75k code samples"
        },
        "CodeFeedback-Filtered": {
            "hf_id": "m-a-p/CodeFeedback-Filtered-Instruction",
            "config": None,
            "size": "157k code samples"
        },
        "Evol-Instruct-Code": {
            "hf_id": "nickrosh/Evol-Instruct-Code-80k-v1",
            "config": None,
            "size": "80k evolved code"
        },
    },
    "Advanced Agentic": {
        "AgentInstruct": {
            "hf_id": "microsoft/orca-agentinstruct-1M-v1",
            "config": None,
            "size": "1M agent samples"
        },
        "ToolBench": {
            "hf_id": "ToolBench/ToolBench",
            "config": None,
            "size": "16k tool use"
        },
        "WebArena": {
            "hf_id": "neulab/agent-data-collection",
            "config": "nnetnav-wa",
            "size": "~2k web agent"
        },
    },
    "Instruction Tuning": {
        "Alpaca": {
            "hf_id": "tatsu-lab/alpaca",
            "config": None,
            "size": "52k samples"
        },
        "OpenOrca": {
            "hf_id": "Open-Orca/OpenOrca",
            "config": None,
            "size": "4.2M reasoning"
        },
    }
}
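# To register another dataset (a sketch; the repo id below is a placeholder,
# not a real dataset), append an entry with the same keys to a category:
#
#   DATASETS["Instruction Tuning"]["My Dataset"] = {
#       "hf_id": "your-org/your-dataset",  # hypothetical repo id
#       "config": None,                    # or a named config, as above
#       "size": "10k samples",
#   }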
def train_model(
    model_name: str,
    selected_datasets: List[str],
    max_samples: int,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    output_repo: str
):
    """Main training function"""
    try:
        logs = []

        def log(msg):
            # Generator helper: each call prints, records the line, and yields
            # the accumulated log so Gradio can stream it to the output box.
            print(msg)
            logs.append(msg)
            yield "\n".join(logs)

        yield from log("=" * 80)
        yield from log("🧘 ZEN TRAINING SPACE")
        yield from log("=" * 80)
        yield from log("")

        # GPU info
        yield from log(f"🎮 GPU Available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            yield from log(f"   Device: {torch.cuda.get_device_name(0)}")
            yield from log(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
        yield from log("")
        # Find model config.
        # Handle both "Category / ModelName" and "ModelName" formats.
        if " / " in model_name:
            model_short_name = model_name.split(" / ")[1]
        else:
            model_short_name = model_name

        model_config = None
        for category in MODELS.values():
            if model_short_name in category:
                model_config = category[model_short_name]
                break

        if not model_config:
            yield from log(f"❌ Model {model_short_name} not found")
            return

        yield from log(f"📦 Loading model: {model_short_name}")
        yield from log(f"   HF ID: {model_config['hf_id']}")
        yield from log(f"   Size: {model_config['size']}")
        yield from log(f"   Type: {model_config['type']}")

        # Load model weights. Note that AutoModel returns the base architecture
        # without a task head; depending on the training objective a task-specific
        # class (e.g. a causal-LM variant) may be the better fit.
        model = AutoModel.from_pretrained(
            model_config['hf_id'],
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )

        # Vision-language models need a processor (image + text); plain language
        # models only need a tokenizer.
        if model_config['type'] == "vision-language":
            processor = AutoProcessor.from_pretrained(model_config['hf_id'])
        else:
            processor = AutoTokenizer.from_pretrained(model_config['hf_id'])

        yield from log("✅ Model loaded")
        yield from log("")
        # Load datasets
        yield from log("📚 Loading datasets...")
        all_datasets = []

        # Split the overall sample budget evenly across the selected datasets
        # (guarding against an empty selection).
        per_dataset_cap = max_samples // max(len(selected_datasets), 1)

        for dataset_name in selected_datasets:
            # Handle both "Category / DatasetName" and "DatasetName" formats
            if " / " in dataset_name:
                dataset_short_name = dataset_name.split(" / ", 1)[1]
            else:
                dataset_short_name = dataset_name

            # Find dataset config
            dataset_config = None
            for category in DATASETS.values():
                if dataset_short_name in category:
                    dataset_config = category[dataset_short_name]
                    break

            if not dataset_config:
                yield from log(f"⚠️ Dataset {dataset_short_name} not found, skipping")
                continue

            yield from log(f"   Loading: {dataset_name}")
            yield from log(f"   HF ID: {dataset_config['hf_id']}")

            try:
                if dataset_config['config']:
                    ds = load_dataset(
                        dataset_config['hf_id'],
                        dataset_config['config'],
                        split="train",
                        streaming=True
                    )
                else:
                    ds = load_dataset(
                        dataset_config['hf_id'],
                        split="train",
                        streaming=True
                    )

                # Take a limited number of samples from the stream
                samples = []
                for i, example in enumerate(ds):
                    if i >= per_dataset_cap:
                        break
                    samples.append(example)

                all_datasets.extend(samples)
                yield from log(f"   ✅ Loaded {len(samples)} samples")
            except Exception as e:
                yield from log(f"   ❌ Error: {e}")

        yield from log(f"\n✅ Total samples loaded: {len(all_datasets)}")
        yield from log("")
        # Training setup
        yield from log("⚙️ Training Configuration:")
        yield from log(f"   Epochs: {epochs}")
        yield from log(f"   Batch Size: {batch_size}")
        yield from log(f"   Learning Rate: {learning_rate}")
        yield from log(f"   Samples: {len(all_datasets)}")
        yield from log(f"   Output: {output_repo}")
        yield from log("")

        training_args = TrainingArguments(
            output_dir="./training-output",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            logging_steps=10,
            save_steps=100,
            bf16=True,
            push_to_hub=True,
            hub_model_id=output_repo,
            report_to="tensorboard",
        )
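        # push_to_hub=True assumes the Space (or local environment) has a
        # HuggingFace token with write access to `output_repo`, e.g. an HF_TOKEN
        # secret; without it the final upload will fail.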
        # Create trainer. The raw list of samples is passed straight through;
        # a real run would usually wrap it in a `datasets.Dataset` and tokenize
        # it (with a data collator) so the Trainer receives model-ready tensors.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=all_datasets if len(all_datasets) > 0 else None,
        )

        # Train!
        yield from log("🔥 TRAINING STARTED")
        yield from log("=" * 80)

        result = trainer.train()

        yield from log("")
        yield from log("=" * 80)
        yield from log("✅ TRAINING COMPLETED!")
        yield from log("=" * 80)
        yield from log(f"📊 Final Loss: {result.training_loss:.4f}")
        # Generate model card with dataset info
        yield from log("")
        yield from log("📝 Generating model card...")

        from datetime import datetime, timezone

        # Build dataset info for model card
        dataset_info = []
        dataset_hf_ids = []
        for dataset_name in selected_datasets:
            if " / " in dataset_name:
                dataset_short_name = dataset_name.split(" / ", 1)[1]
            else:
                dataset_short_name = dataset_name

            for category in DATASETS.values():
                if dataset_short_name in category:
                    ds_config = category[dataset_short_name]
                    dataset_info.append(f"- [{dataset_short_name}](https://huggingface.co/datasets/{ds_config['hf_id']}) ({ds_config['size']})")
                    dataset_hf_ids.append(ds_config['hf_id'])
                    break
| model_card = f"""--- | |
| language: | |
| - en | |
| license: apache-2.0 | |
| tags: | |
| - zen | |
| - vision-language | |
| - function-calling | |
| - agent | |
| base_model: {model_config['hf_id']} | |
| datasets: | |
| {chr(10).join([f"- {hf_id}" for hf_id in dataset_hf_ids])} | |
| --- | |
| # {output_repo.split('/')[-1]} | |
| Fine-tuned from [{model_config['hf_id']}](https://huggingface.co/{model_config['hf_id']}) using the Zen Training Space. | |
| ## Training Details | |
| ### Base Model | |
| - **Model**: {model_short_name} | |
| - **Size**: {model_config['size']} parameters | |
| - **Type**: {model_config['type']} | |
| - **Base HF ID**: [{model_config['hf_id']}](https://huggingface.co/{model_config['hf_id']}) | |
| ### Datasets Used | |
| {chr(10).join(dataset_info)} | |
| ### Training Configuration | |
| - **Total Samples**: {len(all_datasets):,} | |
| - **Epochs**: {epochs} | |
| - **Batch Size**: {batch_size} | |
| - **Learning Rate**: {learning_rate} | |
| - **Precision**: bfloat16 | |
| - **Final Training Loss**: {result.training_loss:.4f} | |
| - **Training Date**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')} | |
| ### Hardware | |
| - **GPU**: NVIDIA A10G (24GB) | |
| - **Platform**: HuggingFace Spaces | |
| ## Usage | |
| ```python | |
| from transformers import AutoModel, AutoProcessor | |
| model = AutoModel.from_pretrained("{output_repo}", trust_remote_code=True) | |
| processor = AutoProcessor.from_pretrained("{output_repo}") | |
| # Your inference code here | |
| ``` | |
| ## Training Space | |
| This model was trained using the [Zen Training Space](https://huggingface.co/spaces/zeekay/zen-training), | |
| a unified platform for training all Zen AI models. | |
| ## Citation | |
| ```bibtex | |
| @misc{{{output_repo.replace('/', '_').replace('-', '_')}, | |
| author = {{Zen AI}}, | |
| title = {{{output_repo.split('/')[-1]}}}, | |
| year = {{2025}}, | |
| publisher = {{HuggingFace}}, | |
| url = {{https://huggingface.co/{output_repo}}} | |
| }} | |
| ``` | |
| --- | |
| *Trained with ❤️ using [Zen Training Space](https://huggingface.co/spaces/zeekay/zen-training)* | |
| """ | |
        # Save model card
        os.makedirs("./training-output", exist_ok=True)
        with open("./training-output/README.md", "w") as f:
            f.write(model_card)

        yield from log("✅ Model card generated")
        yield from log(f"☁️ Model uploaded to: {output_repo}")
        yield from log("")
        yield from log("🎉 SUCCESS!")

    except Exception as e:
        yield from log(f"\n❌ ERROR: {str(e)}")
        import traceback
        yield from log(f"\n{traceback.format_exc()}")
# Custom CSS for Inter font and branding
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

* {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
}

body, .gradio-container {
    background-color: #000 !important;
    color: #fff !important;
}

.logo-header {
    display: flex;
    align-items: center;
    gap: 16px;
    margin-bottom: 24px;
}

.logo-header img {
    height: 48px;
    width: auto;
}

h1, h2, h3, h4 {
    font-weight: 600 !important;
    color: #fff !important;
}

p {
    color: #ccc !important;
}
"""
# Build Gradio Interface
with gr.Blocks(title="Zen Training Space", css=custom_css, theme=gr.themes.Soft()) as demo:
    with gr.Row(elem_classes="logo-header"):
        gr.HTML("""
        <div style="display: flex; align-items: center; gap: 16px;">
            <img src="https://zenlm.org/logo.png" alt="Zen LM" style="height: 48px; width: auto;">
            <div>
                <h1 style="margin: 0; font-size: 28px; font-weight: 600;">Zen Training Space</h1>
                <p style="margin: 0; color: #666; font-size: 16px;">Unified Training Platform for All Zen Models</p>
            </div>
        </div>
        """)

    gr.Markdown("""
    Train any Zen model with any dataset combination from HuggingFace.
    All datasets are loaded directly from HF - no local storage needed!
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Model")
            model_choice = gr.Dropdown(
                choices=[
                    *[f"{cat} / {model}" for cat in MODELS for model in MODELS[cat]]
                ],
                label="Model",
                value="Vision-Language Models / zen-vl-4b-instruct"
            )

            gr.Markdown("### 2. Select Datasets")
            dataset_choices = gr.CheckboxGroup(
                choices=[
                    *[f"{cat} / {ds}" for cat in DATASETS for ds in DATASETS[cat]]
                ],
                label="Datasets",
                value=[
                    "Agent Training / ADP - Synatra",
                    "Function Calling / xLAM Function Calling 60k"
                ]
            )

            gr.Markdown("### 3. Training Config")
            max_samples = gr.Slider(100, 100000, value=10000, step=100, label="Max Samples")
            epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
            batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
            learning_rate = gr.Number(value=2e-5, label="Learning Rate")
            output_repo = gr.Textbox(
                value="zenlm/zen-vl-4b-agent-custom",
                label="Output Repository (HuggingFace)"
            )

            train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("### Training Logs")
            output = gr.Textbox(label="", lines=35, max_lines=50, show_label=False)

    train_btn.click(
        train_model,
        inputs=[
            model_choice,
            dataset_choices,
            max_samples,
            epochs,
            batch_size,
            learning_rate,
            output_repo
        ],
        outputs=output
    )

    gr.Markdown("""
    ---
    ### 📊 Available Models
    - **Language**: nano (0.6B), eco (4B), omni (7B), coder (14B), next (32B)
    - **Vision-Language**: zen-vl (4B, 8B, 30B)

    ### 📚 Available Datasets
    - **Agent Training**: ADP (220k+ trajectories across 15+ configs)
    - **Function Calling**: xLAM (60k high-quality examples)
    - **Coding**: Magicoder, CodeFeedback, Evol-Instruct (300k+ code samples)
    - **Advanced Agentic**: AgentInstruct (1M), ToolBench, WebArena
    - **Instruction**: Alpaca (52k), OpenOrca (4.2M)

    ### 💰 Cost Estimates (HF Pro GPU)
    - 4B model: $3-5 for 10k samples
    - 8B model: $8-12 for 10k samples
    - 32B model: $30-50 for 10k samples
    """)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
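# Local usage (a sketch, assuming gradio, torch, transformers and datasets are
# installed): `python app.py` serves the UI on 0.0.0.0:7860, the default port
# used by Gradio Spaces.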