Hexa09 committed on
Commit 7e9a7de · verified · 1 Parent(s): d56296c

Upload hexa_colab.ipynb with huggingface_hub

Files changed (1)
  1. hexa_colab.ipynb +165 -0
hexa_colab.ipynb ADDED
@@ -0,0 +1,165 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Hexa TTS - Free Colab Training (15GB GPU Optimized)\n",
+ "\n",
+ "**Compatibility:** Verified on a **Tesla T4 (15GB usable VRAM)**.\n",
+ "**Model Config:** Hexa-Base (~350M params).\n",
+ "\n",
+ "## Setup\n",
+ "1. Set Runtime to **T4 GPU**.\n",
+ "2. Run all cells."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 1. Install Dependencies\n",
+ "!pip install -q -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
+ "!pip install -q -U transformers accelerate peft bitsandbytes soundfile phonemizer einops"
+ ]
+ },
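+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sanity check: confirm the runtime actually has a T4 before training.\n",
+ "# (nvidia-smi -L lists the GPUs visible to the runtime.)\n",
+ "!nvidia-smi -L"
+ ]
+ },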
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 2. Clone Your Repository\n",
+ "import os\n",
+ "!git clone https://huggingface.co/Hexa09/hexa-tts-5b\n",
+ "root_dir = \"/content/hexa-tts-5b\"\n",
+ "os.chdir(root_dir)\n",
+ "\n",
+ "# Fix paths for Colab\n",
+ "import sys\n",
+ "sys.path.append(root_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 3. System Imports\n",
+ "import torch\n",
+ "import gc\n",
+ "from transformers import Trainer, TrainingArguments\n",
+ "from src.hf_model import HexaModel, HexaHFConfig\n",
+ "from src.config import HexaConfig\n",
+ "from src.dataset_clean import HexaDataset, collate_fn\n",
+ "from get_data import download_data\n",
+ "\n",
+ "# Clear RAM\n",
+ "gc.collect()\n",
+ "torch.cuda.empty_cache()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 4. Generate Synthetic Data\n",
+ "if not os.path.exists(\"./data/metadata.csv\"):\n",
+ " print(\"Generating synthetic data...\")\n",
+ " download_data()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 5. Initialize Model (15GB Safe Config)\n",
+ "print(\"Initializing Hexa-Base (350M)...\")\n",
+ "\n",
+ "# 350M Params = ~1.4GB VRAM (FP32 Weights)\n",
+ "# Full Training State (weights + grads + Adam moments) = ~5-6GB VRAM\n",
+ "# This leaves ~9GB headroom on a 15GB card.\n",
+ "\n",
+ "hexa_conf = HexaConfig(\n",
+ " dim=1024, # model width\n",
+ " depth=24, # transformer blocks\n",
+ " heads=16, # attention heads\n",
+ " dim_head=64 # per-head width (16 heads * 64 = 1024)\n",
+ ")\n",
+ "hf_config = HexaHFConfig(**hexa_conf.__dict__)\n",
+ "\n",
+ "# Initialize directly on GPU\n",
+ "with torch.device(\"cuda\"):\n",
+ " model = HexaModel(hf_config)\n",
+ "\n",
+ "# Keep master weights in FP32; Trainer(fp16=True) autocasts the forward pass.\n",
+ "# (Calling model.half() before training makes the GradScaler fail with\n",
+ "# \"Attempting to unscale FP16 gradients\".)\n",
+ "model.gradient_checkpointing_enable()\n",
+ "model.enable_input_require_grads()\n",
+ "\n",
+ "print(f\"Model Ready. Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 6. Training Arguments (Safe Mode)\n",
+ "args = TrainingArguments(\n",
+ " output_dir=\"./hexa_colab_checkpoints\",\n",
+ " per_device_train_batch_size=2, # Reduced to 2 for 15GB safety\n",
+ " gradient_accumulation_steps=8, # Increased to maintain effective batch size\n",
+ " learning_rate=2e-4,\n",
+ " num_train_epochs=3,\n",
+ " logging_steps=1,\n",
+ " fp16=True,\n",
+ " gradient_checkpointing=True,\n",
+ " report_to=\"none\",\n",
+ " dataloader_num_workers=0\n",
+ ")\n",
+ "\n",
+ "dataset = HexaDataset(\"./data\", hexa_conf)\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " args=args,\n",
+ " train_dataset=dataset,\n",
+ " data_collator=collate_fn,\n",
+ ")\n",
+ "\n",
+ "print(\"Starting Training...\")\n",
+ "trainer.train()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }