rain1024 committed on
Commit
b85c683
·
0 Parent(s):

Initial commit: Vietnamese dependency parser with Biaffine architecture


- UDD1Corpus for loading UDD-1 dataset from HuggingFace
- Training, evaluation, and prediction scripts
- Docker configuration for containerized training
- Support for character LSTM and PhoBERT features

.dockerignore ADDED
@@ -0,0 +1,39 @@
+# Git
+.git
+.gitignore
+
+# Python
+__pycache__
+*.py[cod]
+*.egg-info
+.eggs
+*.egg
+.venv
+venv
+
+# IDE
+.vscode
+.idea
+*.swp
+
+# Build artifacts
+dist
+build
+*.so
+
+# Models (saved at runtime to network volume)
+models/
+*.pt
+*.bin
+
+# Logs
+wandb/
+*.log
+
+# Environment
+.env
+.env.*
+
+# Docs
+*.md
+!README.md
.gitignore ADDED
@@ -0,0 +1,57 @@
+# Environment
+.env
+.env.*
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+.venv/
+venv/
+ENV/
+env/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+
+# Data and models (large files)
+data/
+models/
+tmp/
+*.pt
+*.bin
+*.safetensors
+
+# Logs
+*.log
+wandb/
+
+# Jupyter
+.ipynb_checkpoints/
+
+# OS
+.DS_Store
+Thumbs.db
CLAUDE.md ADDED
@@ -0,0 +1,66 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+Bamboo-1 is a Vietnamese dependency parser using the Biaffine architecture (Dozat & Manning, 2017), trained on the UDD-1 dataset from HuggingFace (`undertheseanlp/UDD-1`).
+
+## Commands
+
+### Setup
+```bash
+uv sync                # Install dependencies
+uv sync --extra dev    # Include pytest and wandb
+uv sync --extra cloud  # Include runpod for cloud training
+```
+
+### Training
+```bash
+uv run scripts/train.py                                        # Default training
+uv run scripts/train.py --feat bert --bert vinai/phobert-base  # With PhoBERT
+uv run scripts/train.py --wandb --wandb-project bamboo-1       # With W&B logging
+```
+
+### Evaluation
+```bash
+uv run scripts/evaluate.py --model models/bamboo-1             # Evaluate on test set
+uv run scripts/evaluate.py --model models/bamboo-1 --detailed  # Per-relation breakdown
+```
+
+### Prediction
+```bash
+uv run scripts/predict.py --model models/bamboo-1              # Interactive mode
+uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam"
+```
+
+## Architecture
+
+```
+bamboo-1/
+├── bamboo1/
+│   └── corpus.py    # UDD1Corpus - downloads from HuggingFace, converts to CoNLL-U
+├── scripts/
+│   ├── train.py     # Training entry point (Click CLI)
+│   ├── evaluate.py  # UAS/LAS evaluation
+│   └── predict.py   # Inference (interactive, file, or single sentence)
+├── data/            # Auto-generated: CoNLL-U files from UDD-1
+└── models/          # Trained model output
+```
+
+**Key dependencies:**
+- `underthesea[deep]` provides the Biaffine parser implementation (`DependencyParser`, `DependencyParserTrainer`)
+- `datasets` for loading UDD-1 from HuggingFace
+- `click` for CLI argument parsing
+
+**Model architecture:**
+- Word + Character LSTM embeddings (or PhoBERT with `--feat bert`)
+- 3-layer BiLSTM encoder (400 hidden units)
+- Biaffine attention for arc and relation prediction
+
+## Key Implementation Details
+
+- **UDD1Corpus** (`bamboo1/corpus.py`): Auto-downloads the dataset on first use; converts HuggingFace format to CoNLL-U files
+- Scripts use PEP 723 inline dependencies and manual `sys.path` manipulation to import the `bamboo1` module
+- Training hyperparameters are CLI flags (see `--help` for each script)
+- Feature types: `char` (character LSTM), `bert` (PhoBERT), `tag` (POS tags)
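The PEP 723 header plus `sys.path` pattern mentioned in the implementation details above can be sketched as follows. This is a hypothetical minimal example, not the scripts' actual code; the `add_repo_root` helper is illustrative.

```python
# /// script
# requires-python = ">=3.10"
# dependencies = ["click"]
# ///
"""Sketch: make `import bamboo1` work when a file in scripts/ runs directly."""
import sys
from pathlib import Path


def add_repo_root(script_path: str) -> str:
    """Prepend the repo root (two levels above the script) to sys.path."""
    root = str(Path(script_path).resolve().parent.parent)
    if root not in sys.path:
        sys.path.insert(0, root)
    return root
```

A script at `scripts/train.py` would call `add_repo_root(__file__)` before importing `bamboo1`.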
README.md ADDED
@@ -0,0 +1,121 @@
+---
+language:
+- vi
+license: mit
+tags:
+- dependency-parsing
+- vietnamese
+- nlp
+- biaffine
+datasets:
+- undertheseanlp/UDD-1
+library_name: underthesea
+pipeline_tag: token-classification
+---
+
+# Bamboo-1: Vietnamese Dependency Parser
+
+A Vietnamese dependency parser trained on the UDD-1 dataset using the Biaffine architecture.
+
+## Overview
+
+Bamboo-1 is a neural dependency parser for Vietnamese that uses:
+- **Architecture**: Biaffine Dependency Parser (Dozat & Manning, 2017)
+- **Dataset**: UDD-1 (Universal Dependency Dataset for Vietnamese)
+- **Features**: Character-level LSTM embeddings
+
+## Installation
+
+```bash
+git clone https://huggingface.co/undertheseanlp/bamboo-1
+cd bamboo-1
+uv sync
+```
+
+## Usage
+
+### Training
+
+```bash
+# Train with default parameters
+uv run scripts/train.py
+
+# Train with custom parameters
+uv run scripts/train.py --output models/bamboo-1 --max-epochs 200 --feat char
+
+# Train with BERT embeddings
+uv run scripts/train.py --feat bert --bert vinai/phobert-base
+
+# Train with Weights & Biases logging
+uv run scripts/train.py --wandb
+```
+
+### Evaluation
+
+```bash
+# Evaluate trained model
+uv run scripts/evaluate.py --model models/bamboo-1
+```
+
+### Prediction
+
+```bash
+# Interactive prediction
+uv run scripts/predict.py --model models/bamboo-1
+
+# Predict from file
+uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu
+```
+
+## Dataset
+
+The UDD-1 dataset is automatically downloaded from HuggingFace:
+- **Source**: `undertheseanlp/UDD-1`
+- **Train**: 18,282 sentences
+- **Validation**: 859 sentences
+- **Test**: 859 sentences
+- **Format**: Universal Dependencies (CoNLL-U)
+
+## Model Architecture
+
+```
+Input: Vietnamese sentence
+        ↓
+Word Embeddings + Character LSTM Embeddings
+        ↓
+BiLSTM Encoder (3 layers, 400 hidden units)
+        ↓
+Biaffine Attention (Arc + Relation)
+        ↓
+Output: Dependency tree (head indices + relation labels)
+```
+
+## Metrics
+
+- **UAS (Unlabeled Attachment Score)**: Percentage of tokens with the correct head
+- **LAS (Labeled Attachment Score)**: Percentage of tokens with the correct head AND relation
+
+## Project Structure
+
+```
+bamboo-1/
+├── README.md
+├── requirements.txt
+├── scripts/
+│   ├── train.py     # Training script
+│   ├── evaluate.py  # Evaluation script
+│   └── predict.py   # Prediction script
+├── bamboo1/
+│   └── corpus.py    # UDD-1 corpus loader
+├── models/          # Trained models (generated)
+└── data/            # Downloaded dataset (generated)
+```
+
+## References
+
+- [UDD-1 Dataset](https://huggingface.co/datasets/undertheseanlp/UDD-1)
+- [Underthesea NLP Toolkit](https://github.com/undertheseanlp/underthesea)
+- [Deep Biaffine Attention for Neural Dependency Parsing](https://arxiv.org/abs/1611.01734)
+
+## License
+
+MIT License
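The two metrics in the README are token-level accuracies over predicted heads and relations. A minimal standalone sketch (a hypothetical helper, not the evaluation script's actual code):

```python
def attachment_scores(gold_heads, gold_rels, pred_heads, pred_rels):
    """Compute UAS and LAS over flat per-token lists of heads and relations."""
    assert len(gold_heads) == len(pred_heads) == len(gold_rels) == len(pred_rels)
    total = len(gold_heads)
    # UAS: head index matches
    correct_head = sum(g == p for g, p in zip(gold_heads, pred_heads))
    # LAS: head index AND relation label both match
    correct_both = sum(
        gh == ph and gr == pr
        for gh, ph, gr, pr in zip(gold_heads, pred_heads, gold_rels, pred_rels)
    )
    return correct_head / total, correct_both / total
```

For example, if the parser gets 2 of 3 heads right and the matching relations right, both scores are 2/3.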
RUNPOD.md ADDED
@@ -0,0 +1,141 @@
+# Training on RunPod
+
+Guide for training the Bamboo-1 Vietnamese Dependency Parser on RunPod.
+
+## Option 1: Manual Setup (Web UI)
+
+### 1. Create a Pod
+
+1. Go to [RunPod Console](https://runpod.io/console/pods)
+2. Click "Deploy"
+3. Select a GPU (recommended: RTX A4000 or RTX 3090)
+4. Choose template: `runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04`
+5. Set disk size: 20GB+
+6. **Expose TCP port 22** (for SSH)
+7. **Add your SSH public key** to the `PUBLIC_KEY` env variable (see Best Practices)
+8. Deploy
+
+### 2. Connect and Train
+
+```bash
+# SSH into the pod or use the Web Terminal
+
+# Install uv
+curl -LsSf https://astral.sh/uv/install.sh | sh
+source $HOME/.local/bin/env
+
+# Clone repo
+git clone https://huggingface.co/undertheseanlp/bamboo-1
+cd bamboo-1
+
+# Install dependencies
+uv sync
+
+# Train with character embeddings
+uv run scripts/train.py --output models/bamboo-1-char --feat char --max-epochs 100
+
+# Or train with BERT (PhoBERT)
+uv run scripts/train.py --output models/bamboo-1-bert --feat bert --max-epochs 50
+```
+
+### 3. Upload Model
+
+```bash
+# Login to HuggingFace
+huggingface-cli login
+
+# Upload trained model
+hf upload undertheseanlp/bamboo-1 models/bamboo-1-char models/bamboo-1-char
+```
+
+## Option 2: RunPod API
+
+### 1. Setup
+
+```bash
+# Install runpod SDK
+uv pip install runpod
+
+# Set API key
+export RUNPOD_API_KEY="your-api-key"
+```
+
+### 2. Launch Training
+
+```bash
+uv run scripts/runpod_setup.py launch --gpu "NVIDIA RTX A4000"
+```
+
+### 3. Monitor
+
+```bash
+# Check status
+uv run scripts/runpod_setup.py status
+
+# Stop when done
+uv run scripts/runpod_setup.py stop <pod-id>
+```
+
+## Option 3: One-liner
+
+SSH into any RunPod instance and run:
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh && source $HOME/.local/bin/env && git clone https://huggingface.co/undertheseanlp/bamboo-1 && cd bamboo-1 && uv sync && uv run scripts/train.py --output models/bamboo-1-char
+```
+
+## GPU Recommendations
+
+| GPU | VRAM | Batch Size | Est. Time |
+|-----|------|------------|-----------|
+| RTX 3090 | 24GB | 5000 | ~2-3 hours |
+| RTX A4000 | 16GB | 3000 | ~3-4 hours |
+| RTX A5000 | 24GB | 5000 | ~2-3 hours |
+| A100 | 40GB | 8000 | ~1-2 hours |
+
+## Training with Weights & Biases
+
+```bash
+# Login to W&B
+wandb login
+
+# Train with logging
+uv run scripts/train.py --output models/bamboo-1-char --wandb --wandb-project bamboo-1
+```
+
+## Cost Estimate
+
+- RTX A4000: ~$0.20/hour → ~$0.80 for full training
+- RTX 3090: ~$0.30/hour → ~$0.90 for full training
+- A100: ~$1.50/hour → ~$2.25 for full training
+
+## Best Practices
+
+### Always enable SSH when creating a pod
+
+When creating a new pod, SSH **must** be configured so you can watch logs:
+
+1. **Expose port 22 (TCP)** in the "Expose Ports" section
+2. **Add your SSH public key** to the `PUBLIC_KEY` environment variable
+
+```bash
+# Get your public key from the local machine
+cat ~/.ssh/id_rsa.pub
+```
+
+Without SSH:
+- You cannot SSH into the pod to view logs
+- Only the Web Terminal is available (slow and inconvenient)
+- The RunPod API does not support viewing logs directly
+
+### Check GPU utilization
+
+After creating a pod, verify that the GPU is actually being used:
+
+```python
+import runpod
+pods = runpod.get_pods()
+# If gpuUtilPercent stays at 0% for a long time, training has not started or has already finished
+```
+
+This avoids wasting money while the GPU sits idle.
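The idle-GPU check can be wrapped in a small filter. This is a sketch: it assumes each pod is a flat dict carrying a `gpuUtilPercent` key, as in the comment above; the real RunPod SDK response may nest this field differently, so verify against an actual `get_pods()` result.

```python
def idle_pods(pods, threshold=0):
    """Return pods whose reported GPU utilization is at or below threshold.

    Assumes each pod dict exposes 'gpuUtilPercent' (field name taken from
    the note above); pods missing the field are treated as idle.
    """
    return [p for p in pods if p.get("gpuUtilPercent", 0) <= threshold]
```

Running this periodically and stopping any pod it flags is one way to avoid paying for idle GPUs.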
bamboo1/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""Bamboo-1: Vietnamese Dependency Parser trained on UDD-1."""
+
+from bamboo1.corpus import UDD1Corpus
+
+__all__ = ["UDD1Corpus"]
+__version__ = "0.1.0"
bamboo1/corpus.py ADDED
@@ -0,0 +1,158 @@
+"""
+UDD-1 Corpus loader for dependency parsing.
+
+This module provides a corpus class that downloads the UDD-1 dataset from
+HuggingFace and converts it to CoNLL-U format for use with the underthesea
+dependency parser trainer.
+"""
+
+from pathlib import Path
+
+
+class UDD1Corpus:
+    """
+    Corpus class for the UDD-1 (Universal Dependency Dataset) for Vietnamese.
+
+    This class downloads the UDD-1 dataset from HuggingFace and converts it to
+    CoNLL-U format files that can be used with the underthesea ParserTrainer.
+
+    Attributes:
+        train: Path to the training data file (CoNLL-U format)
+        dev: Path to the development/validation data file (CoNLL-U format)
+        test: Path to the test data file (CoNLL-U format)
+
+    Example:
+        >>> from bamboo1.corpus import UDD1Corpus
+        >>> corpus = UDD1Corpus()
+        >>> print(corpus.train)  # Path to train.conllu
+    """
+
+    name = "UDD-1"
+
+    def __init__(self, data_dir: str = None, force_download: bool = False):
+        """
+        Initialize the UDD-1 corpus.
+
+        Args:
+            data_dir: Directory to store the converted CoNLL-U files.
+                Defaults to ./data/UDD-1
+            force_download: If True, re-download and convert even if files exist.
+        """
+        if data_dir is None:
+            data_dir = Path(__file__).parent.parent / "data" / "UDD-1"
+        self.data_dir = Path(data_dir)
+        self.data_dir.mkdir(parents=True, exist_ok=True)
+
+        self._train = self.data_dir / "train.conllu"
+        self._dev = self.data_dir / "dev.conllu"
+        self._test = self.data_dir / "test.conllu"
+
+        if force_download or not self._files_exist():
+            self._download_and_convert()
+
+    def _files_exist(self) -> bool:
+        """Check if all required files exist."""
+        return self._train.exists() and self._dev.exists() and self._test.exists()
+
+    def _download_and_convert(self):
+        """Download UDD-1 from HuggingFace and convert to CoNLL-U format."""
+        # Lazy import - only needed when downloading
+        from datasets import load_dataset
+
+        print("Downloading UDD-1 dataset from HuggingFace...")
+        dataset = load_dataset("undertheseanlp/UDD-1")
+
+        print("Converting to CoNLL-U format...")
+        self._convert_split(dataset["train"], self._train)
+        self._convert_split(dataset["validation"], self._dev)
+        self._convert_split(dataset["test"], self._test)
+
+        print(f"Dataset saved to {self.data_dir}")
+        print(f"  Train: {len(dataset['train'])} sentences")
+        print(f"  Dev:   {len(dataset['validation'])} sentences")
+        print(f"  Test:  {len(dataset['test'])} sentences")
+
+    def _convert_split(self, split, output_path: Path):
+        """Convert a dataset split to CoNLL-U format."""
+        with open(output_path, "w", encoding="utf-8") as f:
+            for item in split:
+                sent_id = item.get("sent_id", "")
+                text = item.get("text", "")
+
+                if sent_id:
+                    f.write(f"# sent_id = {sent_id}\n")
+                if text:
+                    f.write(f"# text = {text}\n")
+
+                tokens = item["tokens"]
+                lemmas = item.get("lemmas", ["_"] * len(tokens))
+                upos = item["upos"]
+                xpos = item.get("xpos", ["_"] * len(tokens))
+                feats = item.get("feats", ["_"] * len(tokens))
+                heads = item["head"]
+                deprels = item["deprel"]
+                deps = item.get("deps", ["_"] * len(tokens))
+                misc = item.get("misc", ["_"] * len(tokens))
+
+                for i in range(len(tokens)):
+                    token_id = i + 1
+                    form = tokens[i]
+                    lemma = lemmas[i] if lemmas[i] else "_"
+                    upos_tag = upos[i] if upos[i] else "_"
+                    xpos_tag = xpos[i] if xpos[i] else "_"
+                    feat = feats[i] if feats[i] else "_"
+                    head = int(heads[i]) if heads[i] else 0
+                    deprel = deprels[i] if deprels[i] else "_"
+                    dep = deps[i] if deps[i] else "_"
+                    misc_val = misc[i] if misc[i] else "_"
+
+                    line = f"{token_id}\t{form}\t{lemma}\t{upos_tag}\t{xpos_tag}\t{feat}\t{head}\t{deprel}\t{dep}\t{misc_val}"
+                    f.write(line + "\n")
+
+                f.write("\n")
+
+    @property
+    def train(self) -> str:
+        """Path to training data file."""
+        return str(self._train)
+
+    @property
+    def dev(self) -> str:
+        """Path to development/validation data file."""
+        return str(self._dev)
+
+    @property
+    def test(self) -> str:
+        """Path to test data file."""
+        return str(self._test)
+
+    def get_statistics(self) -> dict:
+        """Get dataset statistics."""
+        # Lazy import - only needed for statistics
+        from datasets import load_dataset
+
+        dataset = load_dataset("undertheseanlp/UDD-1")
+
+        stats = {
+            "train_sentences": len(dataset["train"]),
+            "dev_sentences": len(dataset["validation"]),
+            "test_sentences": len(dataset["test"]),
+            "train_tokens": sum(len(item["tokens"]) for item in dataset["train"]),
+            "dev_tokens": sum(len(item["tokens"]) for item in dataset["validation"]),
+            "test_tokens": sum(len(item["tokens"]) for item in dataset["test"]),
+        }
+
+        all_upos = set()
+        all_deprels = set()
+        for split in ["train", "validation", "test"]:
+            for item in dataset[split]:
+                all_upos.update(item["upos"])
+                all_deprels.update(item["deprel"])
+
+        stats["num_upos_tags"] = len(all_upos)
+        stats["num_deprels"] = len(all_deprels)
+        stats["upos_tags"] = sorted(all_upos)
+        stats["deprels"] = sorted(all_deprels)
+
+        return stats
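A quick way to sanity-check the converter's output is to parse one emitted token line back into its ten CoNLL-U columns. This standalone sketch mirrors the field order used in `_convert_split` above (the `parse_conllu_line` helper is illustrative, not part of the repository):

```python
def parse_conllu_line(line):
    """Split a CoNLL-U token line into its ten named columns."""
    cols = line.rstrip("\n").split("\t")
    names = ["id", "form", "lemma", "upos", "xpos",
             "feats", "head", "deprel", "deps", "misc"]
    # Every token line the converter writes has exactly ten tab-separated fields
    assert len(cols) == 10, f"expected 10 columns, got {len(cols)}"
    return dict(zip(names, cols))
```

For example, a line written by the converter for the token "Tôi" with head 2 and relation `nsubj` round-trips to `{"form": "Tôi", "head": "2", "deprel": "nsubj", ...}`.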
docker/Dockerfile ADDED
@@ -0,0 +1,57 @@
+# Dockerfile for Bamboo-1 Vietnamese Dependency Parser Training
+# Optimized for RunPod deployment
+#
+# Build:
+#   docker build -t bamboo-1:latest -f docker/Dockerfile .
+#
+# Push to Docker Hub:
+#   docker tag bamboo-1:latest <username>/bamboo-1:latest
+#   docker push <username>/bamboo-1:latest
+#
+# RunPod Usage:
+#   - Set image to: <username>/bamboo-1:latest
+#   - Network volume mount: /runpod-volume
+#   - Models saved to: /runpod-volume/models
+#
+# Training commands:
+#   uv run scripts/train.py
+#   uv run scripts/train.py --wandb --wandb-project bamboo-1
+
+# RunPod optimized base image
+# - PyTorch 2.6.0 + CUDA 12.8.1
+# - Python 3.9-3.13 (default 3.12)
+# - JupyterLab, SSH, NGINX pre-installed
+# - uv package manager included
+FROM runpod/pytorch:1.0.2-cu1281-torch260-ubuntu2204
+
+LABEL maintainer="underthesea"
+LABEL description="Bamboo-1 Vietnamese Dependency Parser - RunPod Training"
+
+# Environment variables
+ENV PYTHONUNBUFFERED=1
+
+# Set working directory
+WORKDIR /workspace/bamboo-1
+
+# Copy dependency files first (for Docker layer cache)
+COPY pyproject.toml uv.lock ./
+COPY docker/requirements.txt ./
+
+# Install dependencies with uv
+# Only click and tqdm are needed - PyTorch is in the base image, data is pre-included
+RUN uv pip install --system -r requirements.txt
+
+# Copy project source code
+COPY bamboo1/ ./bamboo1/
+COPY scripts/ ./scripts/
+
+# Copy pre-processed data (UDD-1 CoNLL-U files, ~22MB)
+# No need for the datasets library at runtime
+COPY data/ ./data/
+
+# Create symlink for models to persist on the RunPod network volume
+RUN mkdir -p /runpod-volume/bamboo-1/models && \
+    ln -sf /runpod-volume/bamboo-1/models models
+
+# Default command - start training
+CMD ["uv", "run", "scripts/train.py"]
docker/requirements.txt ADDED
@@ -0,0 +1,5 @@
+# Docker requirements for training
+# - PyTorch: pre-installed in base image
+# - datasets: not needed, data pre-included in image
+click>=8.0.0
+tqdm>=4.60.0
pyproject.toml ADDED
@@ -0,0 +1,29 @@
+[project]
+name = "bamboo-1"
+version = "0.1.0"
+description = "Vietnamese Dependency Parser trained on UDD-1 dataset"
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = [
+    "torch>=2.0.0",
+    "datasets>=2.14.0",
+    "click>=8.0.0",
+    "underthesea>=9.2.0",
+    "transformers>=4.30.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.0.0",
+    "wandb>=0.15.0",
+]
+cloud = [
+    "runpod>=1.6.0",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["bamboo1"]
requirements.txt ADDED
@@ -0,0 +1,5 @@
+underthesea[deep]>=6.8.0
+datasets>=2.14.0
+click>=8.0.0
+torch>=2.0.0
+transformers>=4.30.0
scripts/cost_estimate.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = ["runpod", "python-dotenv", "click"]
4
+ # ///
5
+ """
6
+ Cost estimation utilities for cloud GPU training.
7
+
8
+ Usage:
9
+ from cost_estimate import CostTracker
10
+
11
+ tracker = CostTracker(gpu_type="RTX_A4000")
12
+ tracker.start()
13
+ # ... training loop ...
14
+ tracker.update(epoch=1, total_epochs=100)
15
+ tracker.summary()
16
+ """
17
+
18
+ import time
19
+ from dataclasses import dataclass
20
+ from typing import Optional
21
+
22
+
23
+ # GPU pricing per hour (USD) - RunPod on-demand prices
24
+ GPU_PRICES = {
25
+ "RTX_A4000": 0.20,
26
+ "RTX_A5000": 0.28,
27
+ "RTX_3090": 0.22,
28
+ "RTX_4090": 0.44,
29
+ "A40": 0.39,
30
+ "A100_40GB": 1.09,
31
+ "A100_80GB": 1.59,
32
+ "H100": 2.49,
33
+ "CPU": 0.0, # No GPU cost for CPU-only
34
+ }
35
+
36
+
37
+ def detect_cloud_provider() -> str:
38
+ """Detect cloud provider from environment or metadata."""
39
+ import os
40
+
41
+ # Check environment variables first (most reliable)
42
+ if os.getenv("RUNPOD_POD_ID"):
43
+ return "runpod"
44
+ if os.getenv("LINODE_ID") or os.getenv("LINODE_DATACENTER_ID"):
45
+ return "linode"
46
+ if os.getenv("AWS_EXECUTION_ENV") or os.getenv("AWS_REGION"):
47
+ return "aws"
48
+ if os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCP_PROJECT"):
49
+ return "gcp"
50
+ if os.getenv("AZURE_CLIENT_ID") or os.getenv("MSI_ENDPOINT"):
51
+ return "azure"
52
+ if os.getenv("LAMBDA_LABS_API_KEY"):
53
+ return "lambda"
54
+ if os.getenv("VAST_CONTAINERLABEL"):
55
+ return "vast"
56
+ if os.getenv("COLAB_GPU"):
57
+ return "colab"
58
+ if os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
59
+ return "kaggle"
60
+
61
+ # Check for cloud-specific metadata endpoints
62
+ try:
63
+ import subprocess
64
+
65
+ # Check Linode metadata (uses same IP but different path)
66
+ result = subprocess.run(
67
+ ["curl", "-s", "-m", "1", "http://169.254.169.254/v1/instance"],
68
+ capture_output=True, timeout=2
69
+ )
70
+ if result.returncode == 0 and b"instance" in result.stdout.lower():
71
+ return "linode"
72
+
73
+ # Check for AWS metadata
74
+ result = subprocess.run(
75
+ ["curl", "-s", "-m", "1", "http://169.254.169.254/latest/meta-data/ami-id"],
76
+ capture_output=True, timeout=2
77
+ )
78
+ if result.returncode == 0 and b"ami-" in result.stdout:
79
+ return "aws"
80
+
81
+ # Check GCP metadata
82
+ result = subprocess.run(
83
+ ["curl", "-s", "-m", "1", "-H", "Metadata-Flavor: Google",
84
+ "http://metadata.google.internal/computeMetadata/v1/"],
85
+ capture_output=True, timeout=2
86
+ )
87
+ if result.returncode == 0 and result.stdout:
88
+ return "gcp"
89
+
90
+ except Exception:
91
+ pass
92
+
93
+ # Check /etc files for cloud hints
94
+ try:
95
+ with open("/etc/hostname", "r") as f:
96
+ hostname = f.read().lower()
97
+ if "linode" in hostname:
98
+ return "linode"
99
+ except Exception:
100
+ pass
101
+
102
+ # Check sys_vendor (most reliable for Linode)
103
+ try:
104
+ with open("/sys/class/dmi/id/sys_vendor", "r") as f:
105
+ vendor = f.read().strip().lower()
106
+ if "linode" in vendor:
107
+ return "linode"
108
+ if "amazon" in vendor:
109
+ return "aws"
110
+ if "google" in vendor:
111
+ return "gcp"
112
+ if "microsoft" in vendor:
113
+ return "azure"
114
+ except Exception:
115
+ pass
116
+
117
+ # Check product_name as fallback
118
+ try:
119
+ import subprocess
120
+ result = subprocess.run(
121
+ ["cat", "/sys/class/dmi/id/product_name"],
122
+ capture_output=True, timeout=2
123
+ )
124
+ if result.returncode == 0:
125
+ product = result.stdout.decode().lower()
126
+ if "linode" in product:
127
+ return "linode"
128
+ if "amazon" in product or "ec2" in product:
129
+ return "aws"
130
+ if "google" in product:
131
+ return "gcp"
132
+ except Exception:
133
+ pass
134
+
135
+ return "local"
136
+
137
+
138
+ @dataclass
139
+ class HardwareInfo:
140
+ """Detected hardware information."""
141
+ device_type: str # "cuda" or "cpu"
142
+ gpu_name: Optional[str] = None
143
+ gpu_memory_gb: Optional[float] = None
144
+ cpu_name: Optional[str] = None
145
+ cpu_cores: Optional[int] = None
146
+ ram_gb: Optional[float] = None
147
+ cloud_provider: str = "local"
148
+
149
+ def get_gpu_type(self) -> str:
150
+ """Map detected GPU to pricing category."""
151
+ if self.device_type == "cpu" or not self.gpu_name:
152
+ return "CPU"
153
+
154
+ name = self.gpu_name.upper()
155
+
156
+ # Match known GPU types
157
+ if "H100" in name:
158
+ return "H100"
159
+ elif "A100" in name:
160
+ if self.gpu_memory_gb and self.gpu_memory_gb > 50:
161
+ return "A100_80GB"
162
+ return "A100_40GB"
163
+ elif "A40" in name:
164
+ return "A40"
165
+ elif "4090" in name:
166
+ return "RTX_4090"
167
+ elif "3090" in name:
168
+ return "RTX_3090"
169
+ elif "A5000" in name:
170
+ return "RTX_A5000"
171
+ elif "A4000" in name:
172
+ return "RTX_A4000"
173
+ else:
174
+ return "RTX_A4000" # Default fallback
175
+
176
+ def to_dict(self) -> dict:
177
+ """Convert to dictionary for logging."""
178
+ return {
179
+ "device_type": self.device_type,
180
+ "gpu_name": self.gpu_name,
181
+ "gpu_memory_gb": self.gpu_memory_gb,
182
+ "cpu_name": self.cpu_name,
183
+ "cpu_cores": self.cpu_cores,
184
+ "ram_gb": self.ram_gb,
185
+ "gpu_type": self.get_gpu_type(),
186
+ "cloud_provider": self.cloud_provider,
187
+ }
188
+
189
+ def __str__(self) -> str:
190
+ provider = f"[{self.cloud_provider}] " if self.cloud_provider != "local" else ""
191
+ if self.device_type == "cuda" and self.gpu_name:
192
+ mem = f" ({self.gpu_memory_gb:.1f}GB)" if self.gpu_memory_gb else ""
193
+ return f"{provider}{self.gpu_name}{mem}"
194
+ else:
195
+ ram = f", {self.ram_gb:.1f}GB RAM" if self.ram_gb else ""
196
+ return f"{provider}CPU: {self.cpu_name or 'Unknown'} ({self.cpu_cores} cores{ram})"
197
+
198
+
199
+ def detect_hardware() -> HardwareInfo:
200
+ """Detect available hardware (GPU/CPU) and cloud provider."""
201
+ import platform
202
+ import os
203
+
204
+ # Detect cloud provider
205
+ cloud_provider = detect_cloud_provider()
206
+
207
+ # Get CPU info
208
+ cpu_name = platform.processor() or "Unknown"
209
+ cpu_cores = os.cpu_count()
210
+
211
+ # Get RAM
212
+ try:
213
+ import subprocess
214
+ if platform.system() == "Linux":
215
+ mem_info = subprocess.check_output(["free", "-b"]).decode()
216
+ ram_bytes = int(mem_info.split("\n")[1].split()[1])
217
+ ram_gb = ram_bytes / (1024**3)
218
+ else:
219
+ ram_gb = None
220
+ except Exception:
221
+ ram_gb = None
222
+
223
+ # Try to detect GPU with torch
224
+ try:
225
+ import torch
226
+ if torch.cuda.is_available():
227
+ gpu_name = torch.cuda.get_device_name(0)
228
+ gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
229
+ return HardwareInfo(
230
+ device_type="cuda",
231
+ gpu_name=gpu_name,
232
+ gpu_memory_gb=gpu_memory_gb,
233
+ cpu_name=cpu_name,
234
+ cpu_cores=cpu_cores,
235
+ ram_gb=ram_gb,
236
+ cloud_provider=cloud_provider,
237
+ )
238
+ except Exception:
239
+ pass
240
+
241
+ return HardwareInfo(
242
+ device_type="cpu",
243
+ cpu_name=cpu_name,
244
+ cpu_cores=cpu_cores,
245
+ ram_gb=ram_gb,
246
+ cloud_provider=cloud_provider,
247
+ )
248
+
249
+
250
+ @dataclass
251
+ class CostTracker:
252
+ """Track training time and estimate costs."""
253
+
254
+ gpu_type: str = "RTX_A4000"
255
+
256
+ def __post_init__(self):
257
+ self.start_time: Optional[float] = None
258
+ self.hourly_rate = GPU_PRICES.get(self.gpu_type, 0.20)
259
+ self.last_report_time: Optional[float] = None
260
+ self.report_interval = 300 # Report every 5 minutes
261
+
262
+ def start(self):
263
+ """Start the cost tracker."""
264
+ self.start_time = time.time()
265
+ self.last_report_time = self.start_time
266
+
267
+ def elapsed_seconds(self) -> float:
268
+ """Get elapsed time in seconds."""
269
+ if self.start_time is None:
270
+ return 0
271
+ return time.time() - self.start_time
272
+
273
+ def elapsed_hours(self) -> float:
274
+ """Get elapsed time in hours."""
275
+ return self.elapsed_seconds() / 3600
276
+
277
+ def current_cost(self) -> float:
278
+ """Get current cost in USD."""
279
+ return self.elapsed_hours() * self.hourly_rate
280
+
281
+ def estimate_total_cost(self, progress: float) -> float:
282
+ """
283
+ Estimate total cost based on current progress.
284
+
285
+ Args:
286
+ progress: Training progress (0.0 to 1.0)
287
+ """
288
+ if progress <= 0:
289
+ return 0
290
+ return self.current_cost() / progress
291
+
292
+ def estimate_remaining_cost(self, progress: float) -> float:
293
+ """Estimate remaining cost."""
294
+ return self.estimate_total_cost(progress) - self.current_cost()
295
+
296
+ def estimate_remaining_time(self, progress: float) -> float:
297
+ """Estimate remaining time in seconds."""
298
+ if progress <= 0:
299
+ return 0
300
+ elapsed = self.elapsed_seconds()
301
+ total_time = elapsed / progress
302
+ return total_time - elapsed
303
+
304
+ def format_time(self, seconds: float) -> str:
305
+ """Format seconds to human readable string."""
306
+ if seconds < 60:
307
+ return f"{seconds:.0f}s"
308
+ elif seconds < 3600:
309
+ mins = seconds / 60
310
+ return f"{mins:.1f}m"
311
+ else:
312
+ hours = seconds / 3600
313
+ return f"{hours:.1f}h"
314
+
315
+ def format_cost(self, cost: float) -> str:
316
+ """Format cost to human readable string."""
317
+ if cost < 0.01:
318
+ return f"${cost:.4f}"
319
+ elif cost < 1:
320
+ return f"${cost:.3f}"
321
+ else:
322
+ return f"${cost:.2f}"
323
+
324
+ def should_report(self) -> bool:
325
+ """Check if it's time to report costs."""
326
+ if self.last_report_time is None:
327
+ return True
328
+ return time.time() - self.last_report_time >= self.report_interval
329
+
330
+ def get_status(self, epoch: int, total_epochs: int) -> str:
331
+ """Get formatted status string with cost info."""
332
+ progress = epoch / total_epochs if total_epochs > 0 else 0
333
+
334
+ current = self.current_cost()
335
+ estimated_total = self.estimate_total_cost(progress)
336
+ remaining_time = self.estimate_remaining_time(progress)
337
+
338
+ return (
339
+ f"Cost: {self.format_cost(current)} | "
340
+ f"Est. total: {self.format_cost(estimated_total)} | "
341
+ f"ETA: {self.format_time(remaining_time)}"
342
+ )
343
+
344
+ def update(self, epoch: int, total_epochs: int, force: bool = False) -> Optional[str]:
345
+ """
346
+ Update and optionally return status if report interval passed.
347
+
348
+ Returns status string if it's time to report, None otherwise.
349
+ """
350
+ if force or self.should_report():
351
+ self.last_report_time = time.time()
352
+ return self.get_status(epoch, total_epochs)
353
+ return None
354
+
355
+ def summary(self, epoch: int, total_epochs: int) -> str:
356
+ """Get final summary."""
357
+ progress = epoch / total_epochs if total_epochs > 0 else 1.0
358
+ elapsed = self.elapsed_seconds()
359
+ cost = self.current_cost()
360
+
361
+ lines = [
362
+ "=" * 50,
363
+ "Cost Summary",
364
+ "=" * 50,
365
+ f" GPU: {self.gpu_type} (${self.hourly_rate}/hr)",
366
+ f" Duration: {self.format_time(elapsed)}",
367
+ f" Total cost: {self.format_cost(cost)}",
368
+ ]
369
+
370
+ if progress < 1.0:
371
+ estimated = self.estimate_total_cost(progress)
372
+ lines.append(f" Est. full training: {self.format_cost(estimated)}")
373
+
374
+ lines.append("=" * 50)
375
+ return "\n".join(lines)
376
+
377
+
378
+ def get_runpod_costs(pod_id: str | None = None) -> list[dict]:
379
+ """Get cost info from RunPod API using GraphQL for accurate uptime."""
380
+ import os
381
+ import requests
382
+ from dotenv import load_dotenv
383
+
384
+ load_dotenv()
385
+ api_key = os.getenv("RUNPOD_API_KEY")
386
+
387
+ # Use GraphQL for accurate runtime data
388
+ query = """
389
+ query getMyPods {
390
+ myself {
391
+ pods {
392
+ id
393
+ name
394
+ desiredStatus
395
+ costPerHr
396
+ machine { gpuDisplayName }
397
+ runtime {
398
+ uptimeInSeconds
399
+ gpus { gpuUtilPercent memoryUtilPercent }
400
+ }
401
+ }
402
+ }
403
+ }
404
+ """
405
+
406
+ response = requests.post(
407
+ "https://api.runpod.io/graphql",
408
+ headers={"Authorization": f"Bearer {api_key}"},
409
+ json={"query": query}
410
+ )
411
+ data = response.json()
412
+ pods = data.get("data", {}).get("myself", {}).get("pods", [])
413
+
414
+ if pod_id:
415
+ pods = [p for p in pods if p["id"] == pod_id]
416
+
417
+ results = []
418
+ for pod in pods:
419
+ if pod.get("desiredStatus") != "RUNNING":
420
+ continue
421
+
422
+ cost_per_hr = pod.get("costPerHr", 0)
423
+ runtime = pod.get("runtime") or {}
424
+ uptime_seconds = runtime.get("uptimeInSeconds", 0)
425
+ uptime_hours = uptime_seconds / 3600
426
+ current_cost = cost_per_hr * uptime_hours
427
+
428
+ gpus = runtime.get("gpus") or []
429
+ gpu_util = gpus[0].get("gpuUtilPercent", 0) if gpus else 0
430
+ mem_util = gpus[0].get("memoryUtilPercent", 0) if gpus else 0
431
+
432
+ results.append({
433
+ "id": pod["id"],
434
+ "name": pod.get("name", "N/A"),
435
+ "gpu": (pod.get("machine") or {}).get("gpuDisplayName", "N/A"),
436
+ "cost_per_hr": cost_per_hr,
437
+ "uptime_seconds": uptime_seconds,
438
+ "uptime_hours": uptime_hours,
439
+ "current_cost": current_cost,
440
+ "gpu_util": gpu_util,
441
+ "mem_util": mem_util,
442
+ })
443
+
444
+ return results
445
+
446
+
447
+ def print_runpod_report(pods: list[dict], estimate_hours: float = None):
448
+ """Print RunPod cost report."""
449
+ import click
450
+ from datetime import datetime
451
+
452
+ if not pods:
453
+ click.echo("No running pods found.")
454
+ return
455
+
456
+ click.echo(f"\n{'='*60}")
457
+ click.echo(f" RunPod Cost Report - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
458
+ click.echo(f"{'='*60}\n")
459
+
460
+ total_current = 0
461
+ total_per_hr = 0
462
+
463
+ for pod in pods:
464
+ uptime_str = f"{pod['uptime_seconds']:.0f}s" if pod['uptime_seconds'] < 60 else (f"{pod['uptime_seconds']/60:.1f}m" if pod['uptime_seconds'] < 3600 else f"{pod['uptime_hours']:.1f}h")
465
+ click.echo(f"Pod: {pod['name']} ({pod['id']})")
466
+ click.echo(f" GPU: {pod['gpu']} @ ${pod['cost_per_hr']:.2f}/hr")
467
+ click.echo(f" Uptime: {uptime_str}")
468
+ click.echo(f" Current Cost: ${pod['current_cost']:.4f}")
469
+ click.echo(f" GPU: {pod['gpu_util']:.0f}% | Mem: {pod['mem_util']:.0f}%")
470
+
471
+ if estimate_hours:
472
+ est_total = pod['cost_per_hr'] * estimate_hours
473
+ remaining_hrs = max(0, estimate_hours - pod['uptime_hours'])
474
+ click.echo(f" Est. Total ({estimate_hours}h): ${est_total:.2f} (remaining: ${pod['cost_per_hr'] * remaining_hrs:.2f})")
475
+
476
+ click.echo()
477
+ total_current += pod['current_cost']
478
+ total_per_hr += pod['cost_per_hr']
479
+
480
+ click.echo(f"{'-'*60}")
481
+ click.echo(f"TOTAL: ${total_current:.4f} (${total_per_hr:.2f}/hr)")
482
+ if estimate_hours:
483
+ click.echo(f"Est. Total ({estimate_hours}h): ${total_per_hr * estimate_hours:.2f}")
484
+ click.echo()
485
+
486
+
487
+ def main():
488
+ """Cost estimation CLI."""
489
+ import click
490
+ import os
491
+
492
+ @click.group()
493
+ def cli():
494
+ """Cost estimation for GPU training."""
495
+ pass
496
+
497
+ @cli.command()
498
+ @click.option("--gpu", default="RTX_A4000", type=click.Choice(list(GPU_PRICES.keys())))
499
+ @click.option("--hours", default=1.0, type=float, help="Estimated training hours")
500
+ def estimate(gpu, hours):
501
+ """Estimate training cost for a GPU."""
502
+ rate = GPU_PRICES[gpu]
503
+ cost = rate * hours
504
+
505
+ click.echo(f"GPU: {gpu}")
506
+ click.echo(f"Rate: ${rate}/hour")
507
+ click.echo(f"Duration: {hours} hours")
508
+ click.echo(f"Estimated cost: ${cost:.2f}")
509
+
510
+ @cli.command()
511
+ @click.option("--pod-id", "-p", help="Specific pod ID")
512
+ @click.option("--watch", "-w", is_flag=True, help="Watch mode (refresh every 10s)")
513
+ @click.option("--estimate", "-e", type=float, help="Estimate total for N hours")
514
+ def monitor(pod_id, watch, estimate):
515
+ """Monitor RunPod costs in real-time."""
516
+ if watch:
517
+ try:
518
+ while True:
519
+ os.system("clear" if os.name != "nt" else "cls")
520
+ pods = get_runpod_costs(pod_id)
521
+ print_runpod_report(pods, estimate)
522
+ click.echo("Press Ctrl+C to exit...")
523
+ time.sleep(10)
524
+ except KeyboardInterrupt:
525
+ click.echo("\nExiting...")
526
+ else:
527
+ pods = get_runpod_costs(pod_id)
528
+ print_runpod_report(pods, estimate)
529
+
530
+ cli()
531
+
532
+
533
+ if __name__ == "__main__":
534
+ main()
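The extrapolation used by `estimate_total_cost` above is just cost-so-far divided by fraction complete; a self-contained sketch of that arithmetic (a hypothetical standalone helper mirroring the method, nothing imported from the script):

```python
# Standalone sketch of CostTracker.estimate_total_cost's math (hypothetical
# helper; mirrors the method above but imports nothing from the script).
def estimate_total_cost(elapsed_hours: float, hourly_rate: float, progress: float) -> float:
    """Extrapolate total training cost from spend so far and progress (0..1)."""
    if progress <= 0:
        return 0.0
    return (elapsed_hours * hourly_rate) / progress

# 0.5h at $0.20/hr with 25% of epochs done extrapolates to $0.40 total.
print(round(estimate_total_cost(0.5, 0.20, 0.25), 4))  # 0.4
```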
scripts/evaluate.py ADDED
@@ -0,0 +1,229 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "underthesea[deep]>=6.8.0",
5
+ # "datasets>=2.14.0",
6
+ # "click>=8.0.0",
7
+ # "torch>=2.0.0",
8
+ # "transformers>=4.30.0",
9
+ # ]
10
+ # ///
11
+ """
12
+ Evaluation script for Bamboo-1 Vietnamese Dependency Parser.
13
+
14
+ Usage:
15
+ uv run scripts/evaluate.py --model models/bamboo-1
16
+ uv run scripts/evaluate.py --model models/bamboo-1 --split test
17
+ uv run scripts/evaluate.py --model models/bamboo-1 --detailed
18
+ """
19
+
20
+ import sys
21
+ from pathlib import Path
22
+ from collections import Counter
23
+
24
+ import click
25
+
26
+ # Add parent directory to path for bamboo1 module
27
+ sys.path.insert(0, str(Path(__file__).parent.parent))
28
+
29
+ from bamboo1.corpus import UDD1Corpus
30
+
31
+
32
+ def read_conll_sentences(filepath: str):
33
+ """Read sentences from a CoNLL-U file."""
34
+ sentences = []
35
+ current_sentence = []
36
+
37
+ with open(filepath, "r", encoding="utf-8") as f:
38
+ for line in f:
39
+ line = line.strip()
40
+ if line.startswith("#"):
41
+ continue
42
+ if not line:
43
+ if current_sentence:
44
+ sentences.append(current_sentence)
45
+ current_sentence = []
46
+ else:
47
+ parts = line.split("\t")
48
+ if len(parts) >= 8 and "-" not in parts[0] and "." not in parts[0]:
49
+ current_sentence.append({
50
+ "id": int(parts[0]),
51
+ "form": parts[1],
52
+ "upos": parts[3],
53
+ "head": int(parts[6]),
54
+ "deprel": parts[7],
55
+ })
56
+
57
+ if current_sentence:
58
+ sentences.append(current_sentence)
59
+
60
+ return sentences
61
+
62
+
63
+ def calculate_attachment_scores(gold_sentences, pred_sentences):
64
+ """Calculate UAS and LAS scores."""
65
+ total_tokens = 0
66
+ correct_heads = 0
67
+ correct_labels = 0
68
+
69
+ deprel_stats = Counter()
70
+ deprel_correct = Counter()
71
+
72
+ for gold_sent, pred_sent in zip(gold_sentences, pred_sentences):
73
+ for gold_tok, pred_tok in zip(gold_sent, pred_sent):
74
+ total_tokens += 1
75
+ deprel = gold_tok["deprel"]
76
+ deprel_stats[deprel] += 1
77
+
78
+ if gold_tok["head"] == pred_tok["head"]:
79
+ correct_heads += 1
80
+ if gold_tok["deprel"] == pred_tok["deprel"]:
81
+ correct_labels += 1
82
+ deprel_correct[deprel] += 1
83
+
84
+ uas = correct_heads / total_tokens if total_tokens > 0 else 0
85
+ las = correct_labels / total_tokens if total_tokens > 0 else 0
86
+
87
+ per_deprel_scores = {}
88
+ for deprel in deprel_stats:
89
+ if deprel_stats[deprel] > 0:
90
+ per_deprel_scores[deprel] = {
91
+ "total": deprel_stats[deprel],
92
+ "correct": deprel_correct[deprel],
93
+ "accuracy": deprel_correct[deprel] / deprel_stats[deprel],
94
+ }
95
+
96
+ return {
97
+ "uas": uas,
98
+ "las": las,
99
+ "total_tokens": total_tokens,
100
+ "correct_heads": correct_heads,
101
+ "correct_labels": correct_labels,
102
+ "per_deprel": per_deprel_scores,
103
+ }
104
+
105
+
106
+ @click.command()
107
+ @click.option(
108
+ "--model", "-m",
109
+ required=True,
110
+ help="Path to trained model directory",
111
+ )
112
+ @click.option(
113
+ "--split",
114
+ type=click.Choice(["dev", "test", "both"]),
115
+ default="test",
116
+ help="Dataset split to evaluate on",
117
+ show_default=True,
118
+ )
119
+ @click.option(
120
+ "--detailed",
121
+ is_flag=True,
122
+ help="Show detailed per-relation scores",
123
+ )
124
+ @click.option(
125
+ "--output", "-o",
126
+ help="Save predictions to file (CoNLL-U format)",
127
+ )
128
+ def evaluate(model, split, detailed, output):
129
+ """Evaluate Bamboo-1 Vietnamese Dependency Parser on UDD-1 dataset."""
130
+ from underthesea.models.dependency_parser import DependencyParser
131
+
132
+ click.echo("=" * 60)
133
+ click.echo("Bamboo-1: Vietnamese Dependency Parser Evaluation")
134
+ click.echo("=" * 60)
135
+
136
+ # Load model
137
+ click.echo(f"\nLoading model from {model}...")
138
+ parser = DependencyParser.load(model)
139
+
140
+ # Load corpus
141
+ click.echo("Loading UDD-1 corpus...")
142
+ corpus = UDD1Corpus()
143
+
144
+ splits_to_eval = []
145
+ if split == "both":
146
+ splits_to_eval = [("dev", corpus.dev), ("test", corpus.test)]
147
+ elif split == "dev":
148
+ splits_to_eval = [("dev", corpus.dev)]
149
+ else:
150
+ splits_to_eval = [("test", corpus.test)]
151
+
152
+ for split_name, split_path in splits_to_eval:
153
+ click.echo(f"\n{'=' * 40}")
154
+ click.echo(f"Evaluating on {split_name} set: {split_path}")
155
+ click.echo("=" * 40)
156
+
157
+ # Read gold data
158
+ gold_sentences = read_conll_sentences(split_path)
159
+ click.echo(f" Sentences: {len(gold_sentences)}")
160
+ click.echo(f" Tokens: {sum(len(s) for s in gold_sentences)}")
161
+
162
+ # Make predictions
163
+ click.echo("\nMaking predictions...")
164
+ pred_sentences = []
165
+
166
+ for gold_sent in gold_sentences:
167
+ # Reconstruct the sentence from gold tokens; assumes the parser keeps this whitespace tokenization so gold and predicted tokens align one-to-one
168
+ tokens = [tok["form"] for tok in gold_sent]
169
+ text = " ".join(tokens)
170
+
171
+ # Parse
172
+ result = parser.predict(text)
173
+
174
+ # Convert result to same format as gold
175
+ pred_sent = []
176
+ for i, (word, head, deprel) in enumerate(result):
177
+ pred_sent.append({
178
+ "id": i + 1,
179
+ "form": word,
180
+ "head": head,
181
+ "deprel": deprel,
182
+ })
183
+ pred_sentences.append(pred_sent)
184
+
185
+ # Calculate scores
186
+ scores = calculate_attachment_scores(gold_sentences, pred_sentences)
187
+
188
+ click.echo("\nResults:")
189
+ click.echo(f" UAS: {scores['uas']:.4f} ({scores['uas']*100:.2f}%)")
190
+ click.echo(f" LAS: {scores['las']:.4f} ({scores['las']*100:.2f}%)")
191
+ click.echo(f" Total tokens: {scores['total_tokens']}")
192
+ click.echo(f" Correct heads: {scores['correct_heads']}")
193
+ click.echo(f" Correct labels: {scores['correct_labels']}")
194
+
195
+ if detailed:
196
+ click.echo("\nPer-relation scores:")
197
+ click.echo("-" * 50)
198
+ click.echo(f"{'Relation':<15} {'Count':>8} {'Correct':>8} {'Accuracy':>10}")
199
+ click.echo("-" * 50)
200
+
201
+ for deprel in sorted(scores["per_deprel"].keys()):
202
+ stats = scores["per_deprel"][deprel]
203
+ click.echo(
204
+ f"{deprel:<15} {stats['total']:>8} {stats['correct']:>8} "
205
+ f"{stats['accuracy']*100:>9.2f}%"
206
+ )
207
+
208
+ # Save predictions if requested
209
+ if output:
210
+ out_path = Path(output)
211
+ if split_name != "test":
212
+ out_path = out_path.with_stem(f"{out_path.stem}_{split_name}")
213
+
214
+ click.echo(f"\nSaving predictions to {out_path}...")
215
+ with open(out_path, "w", encoding="utf-8") as f:
216
+ for i, (gold_sent, pred_sent) in enumerate(zip(gold_sentences, pred_sentences)):
217
+ f.write(f"# sent_id = {i + 1}\n")
218
+ for gold_tok, pred_tok in zip(gold_sent, pred_sent):
219
+ f.write(
220
+ f"{gold_tok['id']}\t{gold_tok['form']}\t_\t{gold_tok['upos']}\t_\t_\t"
221
+ f"{pred_tok['head']}\t{pred_tok['deprel']}\t_\t_\n"
222
+ )
223
+ f.write("\n")
224
+
225
+ click.echo("\nEvaluation complete!")
226
+
227
+
228
+ if __name__ == "__main__":
229
+ evaluate()
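A tiny worked example of the UAS/LAS arithmetic in `calculate_attachment_scores`, inlined here on toy data rather than imported: UAS counts correct heads, LAS counts tokens whose head and label are both correct.

```python
# Toy gold/pred tokens in the same dict shape the script uses.
gold = [{"head": 2, "deprel": "nsubj"}, {"head": 0, "deprel": "root"}, {"head": 2, "deprel": "obj"}]
pred = [{"head": 2, "deprel": "nsubj"}, {"head": 0, "deprel": "root"}, {"head": 2, "deprel": "nmod"}]

total = len(gold)
# UAS: head attachment correct; LAS: head AND dependency label correct.
correct_heads = sum(g["head"] == p["head"] for g, p in zip(gold, pred))
correct_labels = sum(g["head"] == p["head"] and g["deprel"] == p["deprel"] for g, p in zip(gold, pred))

uas, las = correct_heads / total, correct_labels / total
print(uas, round(las, 3))  # 1.0 0.667
```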
scripts/predict.py ADDED
@@ -0,0 +1,173 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "underthesea[deep]>=6.8.0",
5
+ # "click>=8.0.0",
6
+ # "torch>=2.0.0",
7
+ # "transformers>=4.30.0",
8
+ # ]
9
+ # ///
10
+ """
11
+ Prediction script for Bamboo-1 Vietnamese Dependency Parser.
12
+
13
+ Usage:
14
+ # Interactive mode
15
+ uv run scripts/predict.py --model models/bamboo-1
16
+
17
+ # File input
18
+ uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu
19
+
20
+ # Single sentence
21
+ uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam"
22
+ """
23
+
24
+ import sys
25
+ from pathlib import Path
26
+
27
+ import click
28
+
29
+
30
+ def format_tree_ascii(tokens, heads, deprels):
31
+ """Format dependency tree as ASCII art."""
32
+ n = len(tokens)
33
+ lines = []
34
+
35
+ # Header
36
+ lines.append(" " + " ".join(f"{i+1:>3}" for i in range(n)))
37
+ lines.append(" " + " ".join(f"{t[:3]:>3}" for t in tokens))
38
+
39
+ # Draw arcs
40
+ for i in range(n):
41
+ head = heads[i]
42
+ if head == 0:
43
+ lines.append(f" {tokens[i]} <- ROOT ({deprels[i]})")
44
+ else:
45
+ arrow = "<-" if head > i + 1 else "->"
46
+ lines.append(f" {tokens[i]} {arrow} {tokens[head-1]} ({deprels[i]})")
47
+
48
+ return "\n".join(lines)
49
+
50
+
51
+ def format_conllu(tokens, heads, deprels, sent_id=None, text=None):
52
+ """Format result as CoNLL-U."""
53
+ lines = []
54
+ if sent_id:
55
+ lines.append(f"# sent_id = {sent_id}")
56
+ if text:
57
+ lines.append(f"# text = {text}")
58
+
59
+ for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels)):
60
+ lines.append(f"{i+1}\t{token}\t_\t_\t_\t_\t{head}\t{deprel}\t_\t_")
61
+
62
+ lines.append("")
63
+ return "\n".join(lines)
64
+
65
+
66
+ @click.command()
67
+ @click.option(
68
+ "--model", "-m",
69
+ required=True,
70
+ help="Path to trained model directory",
71
+ )
72
+ @click.option(
73
+ "--input", "-i",
74
+ "input_file",
75
+ help="Input file (one sentence per line)",
76
+ )
77
+ @click.option(
78
+ "--output", "-o",
79
+ "output_file",
80
+ help="Output file (CoNLL-U format)",
81
+ )
82
+ @click.option(
83
+ "--text", "-t",
84
+ help="Single sentence to parse",
85
+ )
86
+ @click.option(
87
+ "--format",
88
+ "output_format",
89
+ type=click.Choice(["conllu", "simple", "tree"]),
90
+ default="simple",
91
+ help="Output format",
92
+ show_default=True,
93
+ )
94
+ def predict(model, input_file, output_file, text, output_format):
95
+ """Parse Vietnamese sentences with Bamboo-1 Dependency Parser."""
96
+ from underthesea.models.dependency_parser import DependencyParser
97
+
98
+ click.echo(f"Loading model from {model}...")
99
+ parser = DependencyParser.load(model)
100
+ click.echo("Model loaded.\n")
101
+
102
+ def parse_and_print(sentence, sent_id=None):
103
+ """Parse a sentence and print the result."""
104
+ result = parser.predict(sentence)
105
+ tokens = [r[0] for r in result]
106
+ heads = [r[1] for r in result]
107
+ deprels = [r[2] for r in result]
108
+
109
+ if output_format == "conllu":
110
+ return format_conllu(tokens, heads, deprels, sent_id, sentence)
111
+ elif output_format == "tree":
112
+ output = f"Sentence: {sentence}\n"
113
+ output += format_tree_ascii(tokens, heads, deprels)
114
+ return output
115
+ else: # simple
116
+ output = f"Input: {sentence}\n"
117
+ output += "Output:\n"
118
+ for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels)):
119
+ head_word = "ROOT" if head == 0 else tokens[head - 1]
120
+ output += f" {i+1}. {token} -> {head_word} ({deprel})\n"
121
+ return output
122
+
123
+ # Single text mode
124
+ if text:
125
+ result = parse_and_print(text, sent_id=1)
126
+ click.echo(result)
127
+ return
128
+
129
+ # File mode
130
+ if input_file:
131
+ click.echo(f"Reading from {input_file}...")
132
+ with open(input_file, "r", encoding="utf-8") as f:
133
+ sentences = [line.strip() for line in f if line.strip()]
134
+
135
+ click.echo(f"Parsing {len(sentences)} sentences...")
136
+ results = []
137
+ for i, sentence in enumerate(sentences, 1):
138
+ result = parse_and_print(sentence, sent_id=i)
139
+ results.append(result)
140
+ if i % 100 == 0:
141
+ click.echo(f" Processed {i}/{len(sentences)}...")
142
+
143
+ if output_file:
144
+ with open(output_file, "w", encoding="utf-8") as f:
145
+ f.write("\n".join(results))
146
+ click.echo(f"Results saved to {output_file}")
147
+ else:
148
+ for result in results:
149
+ click.echo(result)
150
+ click.echo()
151
+ return
152
+
153
+ # Interactive mode
154
+ click.echo("Interactive mode. Enter sentences to parse (Ctrl+C to exit).\n")
155
+ sent_id = 1
156
+ while True:
157
+ try:
158
+ sentence = input(">>> ").strip()
159
+ if not sentence:
160
+ continue
161
+ result = parse_and_print(sentence, sent_id=sent_id)
162
+ click.echo(result)
163
+ click.echo()
164
+ sent_id += 1
165
+ except KeyboardInterrupt:
166
+ click.echo("\nGoodbye!")
167
+ break
168
+ except EOFError:
169
+ break
170
+
171
+
172
+ if __name__ == "__main__":
173
+ predict()
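For reference, the 10-column rows that `format_conllu` above emits follow the standard CoNLL-U layout, with only ID, FORM, HEAD, and DEPREL filled in. A standalone sketch (the sample sentence and its heads are illustrative, not parser output):

```python
# Illustrative parse of "Tôi yêu Việt_Nam"; heads/deprels are made up here.
tokens = ["Tôi", "yêu", "Việt_Nam"]
heads = [2, 0, 2]
deprels = ["nsubj", "root", "obj"]

# 10 tab-separated CoNLL-U columns; unused ones are "_" placeholders.
lines = [f"{i+1}\t{tok}\t_\t_\t_\t_\t{head}\t{rel}\t_\t_"
         for i, (tok, head, rel) in enumerate(zip(tokens, heads, deprels))]
print("\n".join(lines))
```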
scripts/runpod_setup.py ADDED
@@ -0,0 +1,287 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "runpod>=1.6.0",
5
+ # "requests>=2.28.0",
6
+ # ]
7
+ # ///
8
+ """
9
+ RunPod setup script for Bamboo-1 training.
10
+
11
+ Usage:
12
+ # Set your RunPod API key
13
+ export RUNPOD_API_KEY="your-api-key"
14
+
15
+ # Create a network volume for data
16
+ uv run scripts/runpod_setup.py volume-create --name bamboo-data --size 10
17
+
18
+ # List volumes
19
+ uv run scripts/runpod_setup.py volume-list
20
+
21
+ # Launch training pod with volume
22
+ uv run scripts/runpod_setup.py launch --volume <volume-id>
23
+
24
+ # Check pod status
25
+ uv run scripts/runpod_setup.py status
26
+
27
+ # Stop pod
28
+ uv run scripts/runpod_setup.py stop
29
+ """
30
+
31
+ import os
32
+ import click
33
+ import runpod
34
+ import requests
35
+
36
+
37
+ @click.group()
38
+ def cli():
39
+ """RunPod management for Bamboo-1 training."""
40
+ api_key = os.environ.get("RUNPOD_API_KEY")
41
+ if not api_key:
42
+ raise click.ClickException(
43
+ "RUNPOD_API_KEY environment variable not set.\n"
44
+ "Get your API key from https://runpod.io/console/user/settings"
45
+ )
46
+ runpod.api_key = api_key
47
+
48
+
49
+ def get_ssh_public_key() -> str | None:
50
+ """Get the user's SSH public key."""
51
+ from pathlib import Path
52
+ for key_file in ["~/.ssh/id_rsa.pub", "~/.ssh/id_ed25519.pub"]:
53
+ path = Path(key_file).expanduser()
54
+ if path.exists():
55
+ return path.read_text().strip()
56
+ return None
57
+
58
+
59
+ # Default images
60
+ DEFAULT_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
61
+ BAMBOO1_IMAGE = "undertheseanlp/bamboo-1:latest" # Pre-built image with dependencies
62
+
63
+
64
+ @cli.command()
65
+ @click.option("--gpu", default="NVIDIA RTX A4000", help="GPU type")
66
+ @click.option("--image", default=DEFAULT_IMAGE, help="Docker image")
67
+ @click.option("--prebuilt", is_flag=True, help="Use pre-built bamboo-1 image (faster startup)")
68
+ @click.option("--disk", default=20, type=int, help="Disk size in GB")
69
+ @click.option("--name", default="bamboo-1-training", help="Pod name")
70
+ @click.option("--volume", default=None, help="Network volume ID to attach")
71
+ @click.option("--wandb-key", envvar="WANDB_API_KEY", help="W&B API key for logging")
72
+ @click.option("--sample", default=0, type=int, help="Sample N sentences (0=all)")
73
+ @click.option("--epochs", default=100, type=int, help="Number of epochs")
74
+ def launch(gpu, image, prebuilt, disk, name, volume, wandb_key, sample, epochs):
75
+ """Launch a RunPod instance for training."""
76
+
77
+ # Use pre-built image if requested
78
+ if prebuilt:
79
+ image = BAMBOO1_IMAGE
80
+
81
+ click.echo("Launching RunPod instance...")
82
+ click.echo(f" GPU: {gpu}")
83
+ click.echo(f" Image: {image}")
84
+ click.echo(f" Disk: {disk}GB")
85
+
86
+ # Build training command
87
+ train_cmd = "uv run scripts/train.py"
88
+ if sample > 0:
89
+ train_cmd += f" --sample {sample}"
90
+ train_cmd += f" --epochs {epochs}"
91
+ if wandb_key:
92
+ train_cmd += " --wandb --wandb-project bamboo-1"
93
+
94
+ # Set environment variables
95
+ env_vars = {}
96
+ if wandb_key:
97
+ env_vars["WANDB_API_KEY"] = wandb_key
98
+
99
+ # Add SSH public key
100
+ ssh_key = get_ssh_public_key()
101
+ if ssh_key:
102
+ env_vars["PUBLIC_KEY"] = ssh_key
103
+ click.echo(" SSH key: configured")
104
+
105
+ if volume:
106
+ click.echo(f" Volume: {volume}")
107
+
108
+ pod = runpod.create_pod(
109
+ name=name,
110
+ image_name=image,
111
+ gpu_type_id=gpu,
112
+ volume_in_gb=disk,
113
+ env=env_vars if env_vars else None,
114
+ ports="22/tcp", # Expose SSH port
115
+ network_volume_id=volume, # Attach network volume
116
+ )
117
+
118
+ click.echo("\nPod created!")
119
+ click.echo(f" ID: {pod['id']}")
120
+ click.echo(f" Status: {pod.get('desiredStatus', 'PENDING')}")
121
+ click.echo("\nMonitor at: https://runpod.io/console/pods")
122
+
123
+ # Generate one-liner training command
124
+ click.echo("\n" + "="*60)
125
+ click.echo("SSH into the pod and run this command:")
126
+ click.echo("="*60)
127
+
128
+ if prebuilt:
129
+ # Pre-built image: dependencies already installed
130
+ one_liner = f"cd /workspace/bamboo-1 && {train_cmd}"
131
+ else:
132
+ # Standard image: need to install everything
133
+ one_liner = f"""curl -LsSf https://astral.sh/uv/install.sh | sh && source $HOME/.local/bin/env && git clone https://huggingface.co/undertheseanlp/bamboo-1 && cd bamboo-1 && uv sync && {train_cmd}"""
134
+
135
+ click.echo(one_liner)
136
+ click.echo("="*60)
137
+
138
+
139
+ @cli.command()
140
+ def status():
141
+ """Check status of all pods."""
142
+ pods = runpod.get_pods()
143
+
144
+ if not pods:
145
+ click.echo("No active pods.")
146
+ return
147
+
148
+ click.echo("Active pods:")
149
+ for pod in pods:
150
+ click.echo(f" - {pod['name']} ({pod['id']}): {pod.get('desiredStatus', 'UNKNOWN')}")
151
+
152
+
153
+ @cli.command()
154
+ @click.argument("pod_id")
155
+ def stop(pod_id):
156
+ """Stop a pod by ID."""
157
+ click.echo(f"Stopping pod {pod_id}...")
158
+ runpod.stop_pod(pod_id)
159
+ click.echo("Pod stopped.")
160
+
161
+
162
+ @cli.command()
163
+ @click.argument("pod_id")
164
+ def terminate(pod_id):
165
+ """Terminate a pod by ID."""
166
+ click.echo(f"Terminating pod {pod_id}...")
167
+ runpod.terminate_pod(pod_id)
168
+ click.echo("Pod terminated.")
169
+
170
+
171
+ # =============================================================================
172
+ # Volume Management
173
+ # =============================================================================
174
+
175
+ DATACENTERS = {
176
+ "EU-RO-1": "Europe (Romania)",
177
+ "EU-CZ-1": "Europe (Czech Republic)",
178
+ "EUR-IS-1": "Europe (Iceland)",
179
+ "US-KS-2": "US (Kansas)",
180
+ "US-CA-2": "US (California)",
181
+ }
182
+
183
+
184
+ def _graphql_request(query: str, variables: dict = None) -> dict:
185
+ """Make a GraphQL request to RunPod API."""
186
+ api_key = os.environ.get("RUNPOD_API_KEY")
187
+ response = requests.post(
188
+ "https://api.runpod.io/graphql",
189
+ headers={"Authorization": f"Bearer {api_key}"},
190
+ json={"query": query, "variables": variables or {}}
191
+ )
192
+ return response.json()
193
+
194
+
195
+ @cli.command("volume-list")
196
+ def volume_list():
197
+ """List all network volumes."""
198
+ query = """
199
+ query {
200
+ myself {
201
+ networkVolumes {
202
+ id
203
+ name
204
+ size
205
+ dataCenterId
206
+ }
207
+ }
208
+ }
209
+ """
210
+ result = _graphql_request(query)
211
+ volumes = result.get("data", {}).get("myself", {}).get("networkVolumes", [])
212
+
213
+ if not volumes:
214
+ click.echo("No network volumes found.")
215
+ click.echo("\nCreate one with: uv run scripts/runpod_setup.py volume-create --name bamboo-data --size 10")
216
+ return
217
+
218
+ click.echo("Network Volumes:")
219
+ for vol in volumes:
220
+ dc = DATACENTERS.get(vol['dataCenterId'], vol['dataCenterId'])
221
+ click.echo(f" - {vol['name']} ({vol['id']}): {vol['size']}GB @ {dc}")
222
+
223
+
224
+ @cli.command("volume-create")
225
+ @click.option("--name", default="bamboo-data", help="Volume name")
226
+ @click.option("--size", default=10, type=int, help="Size in GB")
227
+ @click.option("--datacenter", default="EUR-IS-1", type=click.Choice(list(DATACENTERS.keys())), help="Datacenter")
228
+ def volume_create(name, size, datacenter):
229
+ """Create a network volume for data storage."""
230
+ click.echo("Creating network volume...")
231
+ click.echo(f" Name: {name}")
232
+ click.echo(f" Size: {size}GB")
233
+ click.echo(f" Datacenter: {DATACENTERS[datacenter]}")
234
+
235
+ query = """
236
+ mutation createNetworkVolume($input: CreateNetworkVolumeInput!) {
237
+ createNetworkVolume(input: $input) {
238
+ id
239
+ name
240
+ size
241
+ dataCenterId
242
+ }
243
+ }
244
+ """
245
+ variables = {
246
+ "input": {
247
+ "name": name,
248
+ "size": size,
249
+ "dataCenterId": datacenter
250
+ }
251
+ }
252
+
253
+ result = _graphql_request(query, variables)
254
+
255
+ if "errors" in result:
256
+ click.echo(f"\nError: {result['errors'][0]['message']}")
257
+ return
258
+
259
+ volume = result.get("data", {}).get("createNetworkVolume", {})
260
+ click.echo("\nVolume created!")
261
+ click.echo(f" ID: {volume['id']}")
262
+ click.echo(f"\nUse with: uv run scripts/runpod_setup.py launch --volume {volume['id']}")
263
+
264
+
265
+ @cli.command("volume-delete")
266
+ @click.argument("volume_id")
267
+ @click.confirmation_option(prompt="Are you sure you want to delete this volume?")
268
+ def volume_delete(volume_id):
269
+ """Delete a network volume."""
270
+ query = """
271
+ mutation deleteNetworkVolume($input: DeleteNetworkVolumeInput!) {
272
+ deleteNetworkVolume(input: $input)
273
+ }
274
+ """
275
+ variables = {"input": {"id": volume_id}}
276
+
277
+ result = _graphql_request(query, variables)
278
+
279
+ if "errors" in result:
280
+ click.echo(f"Error: {result['errors'][0]['message']}")
281
+ return
282
+
283
+ click.echo(f"Volume {volume_id} deleted.")
284
+
285
+
286
+ if __name__ == "__main__":
287
+ cli()
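The GraphQL calls above all share one payload shape: a JSON object with `query` and `variables` keys POSTed to the RunPod endpoint. A minimal offline illustration (no request is sent; the query string is a trimmed copy of the volume-list query):

```python
import json

# Payload shape used by _graphql_request above; illustration only, no HTTP call.
query = "query { myself { networkVolumes { id name size dataCenterId } } }"
payload = {"query": query, "variables": {}}

print(json.dumps(payload))
```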
scripts/runpod_simple_test.py ADDED
@@ -0,0 +1,81 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "runpod>=1.6.0",
5
+ # "click>=8.0.0",
6
+ # ]
7
+ # ///
8
+ """
9
+ Simple RunPod test script to verify API connection and GPU availability.
10
+
11
+ Usage:
12
+ export RUNPOD_API_KEY="your-api-key"
13
+ uv run scripts/runpod_simple_test.py
14
+ uv run scripts/runpod_simple_test.py --list-gpus
15
+ uv run scripts/runpod_simple_test.py --run-test
16
+ """
17
+
18
+ import os
19
+ import click
20
+ import runpod
21
+
22
+
23
+ @click.command()
24
+ @click.option("--list-gpus", is_flag=True, help="List available GPU types")
25
+ @click.option("--run-test", is_flag=True, help="Run a quick test pod")
26
+ def main(list_gpus, run_test):
27
+ """Test RunPod API connection and GPU availability."""
28
+ api_key = os.environ.get("RUNPOD_API_KEY")
29
+ if not api_key:
30
+ raise click.ClickException(
31
+ "Set RUNPOD_API_KEY environment variable.\n"
32
+ "Get your key at: https://runpod.io/console/user/settings"
33
+ )
34
+
35
+ runpod.api_key = api_key
36
+
37
+ # Test API connection
38
+ click.echo("Testing RunPod API connection...")
39
+ try:
40
+ pods = runpod.get_pods()
41
+ click.echo(f" Connected! Active pods: {len(pods)}")
42
+ except Exception as e:
43
+ raise click.ClickException(f"API connection failed: {e}")
44
+
45
+ # List GPUs
46
+ if list_gpus:
47
+ click.echo("\nAvailable GPU types:")
48
+ try:
49
+ gpus = runpod.get_gpus()
50
+ for gpu in gpus:
51
+ name = gpu.get("id", "Unknown")
52
+ mem = gpu.get("memoryInGb", "?")
53
+ click.echo(f" - {name} ({mem}GB)")
54
+ except Exception as e:
55
+ click.echo(f" Could not list GPUs: {e}")
56
+
57
+ # Run test pod
58
+ if run_test:
59
+ click.echo("\nLaunching test pod...")
60
+ test_script = 'nvidia-smi && python3 -c "import torch; print(torch.__version__, torch.cuda.is_available())"'
61
+
62
+ pod = runpod.create_pod(
63
+ name="bamboo-1-test",
64
+ image_name="runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
65
+ gpu_type_id="NVIDIA RTX A4000",
66
+ volume_in_gb=5,
67
+ docker_args=f"bash -c '{test_script}; sleep 60'",
68
+ )
69
+
70
+ click.echo(f" Pod ID: {pod['id']}")
71
+ click.echo(" Monitor: https://runpod.io/console/pods")
72
+ click.echo(f"\n Terminate after checking:")
73
+ click.echo(f" uv run scripts/runpod_setup.py terminate {pod['id']}")
74
+
75
+ if not list_gpus and not run_test:
76
+ click.echo("\nUse --list-gpus to see available GPUs")
77
+ click.echo("Use --run-test to launch a quick test pod")
78
+
79
+
80
+ if __name__ == "__main__":
81
+ main()
scripts/runpod_train.sh ADDED
@@ -0,0 +1,42 @@
+ #!/bin/bash
+ # RunPod Training Script for Bamboo-1
+ # Usage: bash scripts/runpod_train.sh
+
+ set -e
+
+ echo "=========================================="
+ echo "Bamboo-1: Vietnamese Dependency Parser"
+ echo "RunPod Training Setup"
+ echo "=========================================="
+
+ # Install uv if not present
+ if ! command -v uv &> /dev/null; then
+     echo "Installing uv..."
+     curl -LsSf https://astral.sh/uv/install.sh | sh
+     source "$HOME/.local/bin/env"
+ fi
+
+ # Clone the repo if it does not exist yet
+ if [ ! -d "bamboo-1" ]; then
+     echo "Cloning bamboo-1 from HuggingFace..."
+     git clone https://huggingface.co/undertheseanlp/bamboo-1
+ fi
+
+ cd bamboo-1
+
+ # Install dependencies
+ echo "Installing dependencies..."
+ uv sync
+
+ # Run training (flags match the options scripts/train.py actually defines)
+ echo "Starting training..."
+ uv run scripts/train.py \
+     --output models/bamboo-1-char \
+     --epochs 100 \
+     --batch-size 32 \
+     --lr 2e-3 \
+     --patience 10 \
+     "$@"
+
+ echo "Training complete!"
+ echo "Model saved to: models/bamboo-1-char"
scripts/train.py ADDED
@@ -0,0 +1,673 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "torch>=2.0.0",
+ #     "datasets>=2.14.0",
+ #     "click>=8.0.0",
+ #     "tqdm>=4.60.0",
+ #     "wandb>=0.15.0",
+ # ]
+ # ///
+ """
+ Training script for the Bamboo-1 Vietnamese Dependency Parser.
+ Biaffine parser implemented from scratch (Dozat & Manning, 2017).
+
+ Usage:
+     uv run scripts/train.py
+     uv run scripts/train.py --output models/bamboo-1 --epochs 100
+ """
+
+ import sys
+ from pathlib import Path
+ from collections import Counter
+ from dataclasses import dataclass
+ from typing import List, Tuple, Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
+ from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Adam
+ from torch.optim.lr_scheduler import ExponentialLR
+ from tqdm import tqdm
+
+ import click
+
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+ from bamboo1.corpus import UDD1Corpus
+ from scripts.cost_estimate import CostTracker, detect_hardware
+
+
+ # ============================================================================
+ # Data Processing
+ # ============================================================================
+
+ @dataclass
+ class Sentence:
+     """A dependency-parsed sentence."""
+     words: List[str]
+     heads: List[int]
+     rels: List[str]
+
+
+ def read_conllu(path: str) -> List[Sentence]:
+     """Read a CoNLL-U file and return a list of sentences."""
+     sentences = []
+     words, heads, rels = [], [], []
+
+     with open(path, 'r', encoding='utf-8') as f:
+         for line in f:
+             line = line.strip()
+             if not line:
+                 if words:
+                     sentences.append(Sentence(words, heads, rels))
+                     words, heads, rels = [], [], []
+             elif line.startswith('#'):
+                 continue
+             else:
+                 parts = line.split('\t')
+                 if '-' in parts[0] or '.' in parts[0]:  # Skip multi-word tokens and empty nodes
+                     continue
+                 words.append(parts[1])       # FORM
+                 heads.append(int(parts[6]))  # HEAD
+                 rels.append(parts[7])        # DEPREL
+
+     if words:
+         sentences.append(Sentence(words, heads, rels))
+
+     return sentences
+
+
+ class Vocabulary:
+     """Vocabulary for words, characters, and relations."""
+     PAD = '<pad>'
+     UNK = '<unk>'
+
+     def __init__(self, min_freq: int = 2):
+         self.min_freq = min_freq
+         self.word2idx = {self.PAD: 0, self.UNK: 1}
+         self.char2idx = {self.PAD: 0, self.UNK: 1}
+         self.rel2idx = {}
+         self.idx2rel = {}
+
+     def build(self, sentences: List[Sentence]):
+         """Build the vocabulary from training sentences."""
+         word_counts = Counter()
+         char_counts = Counter()
+         rel_counts = Counter()
+
+         for sent in sentences:
+             for word in sent.words:
+                 word_counts[word.lower()] += 1
+                 for char in word:
+                     char_counts[char] += 1
+             for rel in sent.rels:
+                 rel_counts[rel] += 1
+
+         # Words (frequency-thresholded)
+         for word, count in word_counts.items():
+             if count >= self.min_freq and word not in self.word2idx:
+                 self.word2idx[word] = len(self.word2idx)
+
+         # Characters
+         for char, count in char_counts.items():
+             if char not in self.char2idx:
+                 self.char2idx[char] = len(self.char2idx)
+
+         # Relations
+         for rel in rel_counts:
+             if rel not in self.rel2idx:
+                 idx = len(self.rel2idx)
+                 self.rel2idx[rel] = idx
+                 self.idx2rel[idx] = rel
+
+     def encode_word(self, word: str) -> int:
+         return self.word2idx.get(word.lower(), self.word2idx[self.UNK])
+
+     def encode_char(self, char: str) -> int:
+         return self.char2idx.get(char, self.char2idx[self.UNK])
+
+     def encode_rel(self, rel: str) -> int:
+         return self.rel2idx.get(rel, 0)
+
+     @property
+     def n_words(self) -> int:
+         return len(self.word2idx)
+
+     @property
+     def n_chars(self) -> int:
+         return len(self.char2idx)
+
+     @property
+     def n_rels(self) -> int:
+         return len(self.rel2idx)
+
+
+ class DependencyDataset(Dataset):
+     """Dataset for dependency parsing."""
+
+     def __init__(self, sentences: List[Sentence], vocab: Vocabulary):
+         self.sentences = sentences
+         self.vocab = vocab
+
+     def __len__(self):
+         return len(self.sentences)
+
+     def __getitem__(self, idx):
+         sent = self.sentences[idx]
+
+         # Encode words
+         word_ids = [self.vocab.encode_word(w) for w in sent.words]
+
+         # Encode characters
+         char_ids = [[self.vocab.encode_char(c) for c in w] for w in sent.words]
+
+         # Heads and relations
+         heads = sent.heads
+         rels = [self.vocab.encode_rel(r) for r in sent.rels]
+
+         return word_ids, char_ids, heads, rels
+
+
+ def collate_fn(batch):
+     """Collate function for DataLoader: pad everything to the batch maximum."""
+     word_ids, char_ids, heads, rels = zip(*batch)
+
+     # Sentence lengths
+     lengths = [len(w) for w in word_ids]
+     max_len = max(lengths)
+
+     # Pad words
+     word_ids_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
+     for i, wids in enumerate(word_ids):
+         word_ids_padded[i, :len(wids)] = torch.tensor(wids)
+
+     # Pad characters
+     max_word_len = max(max(len(c) for c in chars) for chars in char_ids)
+     char_ids_padded = torch.zeros(len(batch), max_len, max_word_len, dtype=torch.long)
+     for i, chars in enumerate(char_ids):
+         for j, c in enumerate(chars):
+             char_ids_padded[i, j, :len(c)] = torch.tensor(c)
+
+     # Pad heads
+     heads_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
+     for i, h in enumerate(heads):
+         heads_padded[i, :len(h)] = torch.tensor(h)
+
+     # Pad rels
+     rels_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
+     for i, r in enumerate(rels):
+         rels_padded[i, :len(r)] = torch.tensor(r)
+
+     # Mask of real (non-padding) tokens
+     mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
+     for i, l in enumerate(lengths):
+         mask[i, :l] = True
+
+     lengths = torch.tensor(lengths)
+
+     return word_ids_padded, char_ids_padded, heads_padded, rels_padded, mask, lengths
+
+
+ # ============================================================================
+ # Model
+ # ============================================================================
+
+ class CharLSTM(nn.Module):
+     """Character-level LSTM embeddings."""
+
+     def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
+         super().__init__()
+         self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
+         self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
+         self.hidden_dim = hidden_dim
+
+     def forward(self, chars):
+         """
+         Args:
+             chars: (batch, seq_len, max_word_len)
+         Returns:
+             (batch, seq_len, hidden_dim)
+         """
+         batch, seq_len, max_word_len = chars.shape
+
+         # Flatten words into one batch of character sequences
+         chars_flat = chars.view(-1, max_word_len)  # (batch * seq_len, max_word_len)
+
+         # Word lengths (clamped so padding-only rows can still be packed)
+         word_lens = (chars_flat != 0).sum(dim=1)
+         word_lens = word_lens.clamp(min=1)
+
+         # Embed
+         char_embeds = self.embed(chars_flat)  # (batch * seq_len, max_word_len, char_dim)
+
+         # Pack and run the LSTM
+         packed = pack_padded_sequence(char_embeds, word_lens.cpu(), batch_first=True, enforce_sorted=False)
+         _, (hidden, _) = self.lstm(packed)
+
+         # Concatenate the final forward and backward hidden states
+         hidden = torch.cat([hidden[0], hidden[1]], dim=-1)  # (batch * seq_len, hidden_dim)
+
+         return hidden.view(batch, seq_len, self.hidden_dim)
+
+
+ class MLP(nn.Module):
+     """MLP projection: a linear layer with LeakyReLU and dropout."""
+
+     def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, hidden_dim)
+         self.activation = nn.LeakyReLU(0.1)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.dropout(self.activation(self.linear(x)))
+
+
+ class Biaffine(nn.Module):
+     """Biaffine attention layer."""
+
+     def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True):
+         super().__init__()
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.bias_x = bias_x
+         self.bias_y = bias_y
+
+         self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y))
+         nn.init.xavier_uniform_(self.weight)
+
+     def forward(self, x, y):
+         """
+         Args:
+             x: (batch, seq_len, input_dim) - dependent representations
+             y: (batch, seq_len, input_dim) - head representations
+         Returns:
+             (batch, seq_len, seq_len, output_dim), or (batch, seq_len, seq_len) if output_dim == 1
+         """
+         if self.bias_x:
+             x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
+         if self.bias_y:
+             y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)
+
+         # (batch, seq_len, output_dim, input_dim+1)
+         x = torch.einsum('bxi,oij->bxoj', x, self.weight)
+         # (batch, seq_len, seq_len, output_dim)
+         scores = torch.einsum('bxoj,byj->bxyo', x, y)
+
+         if self.output_dim == 1:
+             scores = scores.squeeze(-1)
+
+         return scores
+
+
+ class BiaffineDependencyParser(nn.Module):
+     """Biaffine dependency parser (Dozat & Manning, 2017)."""
+
+     def __init__(
+         self,
+         n_words: int,
+         n_chars: int,
+         n_rels: int,
+         word_dim: int = 100,
+         char_dim: int = 50,
+         char_hidden: int = 100,
+         lstm_hidden: int = 400,
+         lstm_layers: int = 3,
+         arc_hidden: int = 500,
+         rel_hidden: int = 100,
+         dropout: float = 0.33,
+     ):
+         super().__init__()
+
+         self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
+         self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)
+
+         input_dim = word_dim + char_hidden
+
+         self.lstm = nn.LSTM(
+             input_dim, lstm_hidden // 2,
+             num_layers=lstm_layers,
+             batch_first=True,
+             bidirectional=True,
+             dropout=dropout if lstm_layers > 1 else 0
+         )
+
+         self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
+         self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
+         self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
+         self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)
+
+         self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
+         self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)
+
+         self.dropout = nn.Dropout(dropout)
+         self.n_rels = n_rels
+
+     def forward(self, words, chars, mask):
+         """
+         Args:
+             words: (batch, seq_len)
+             chars: (batch, seq_len, max_word_len)
+             mask: (batch, seq_len)
+         Returns:
+             arc_scores: (batch, seq_len, seq_len)
+             rel_scores: (batch, seq_len, seq_len, n_rels)
+         """
+         # Embeddings
+         word_embeds = self.word_embed(words)
+         char_embeds = self.char_lstm(chars)
+         embeds = torch.cat([word_embeds, char_embeds], dim=-1)
+         embeds = self.dropout(embeds)
+
+         # BiLSTM encoder
+         lengths = mask.sum(dim=1).cpu()
+         packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
+         lstm_out, _ = self.lstm(packed)
+         lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
+         lstm_out = self.dropout(lstm_out)
+
+         # MLP projections
+         arc_dep = self.mlp_arc_dep(lstm_out)
+         arc_head = self.mlp_arc_head(lstm_out)
+         rel_dep = self.mlp_rel_dep(lstm_out)
+         rel_head = self.mlp_rel_head(lstm_out)
+
+         # Biaffine scoring
+         arc_scores = self.arc_attn(arc_dep, arc_head)  # (batch, seq_len, seq_len)
+         rel_scores = self.rel_attn(rel_dep, rel_head)  # (batch, seq_len, seq_len, n_rels)
+
+         return arc_scores, rel_scores
+
+     def loss(self, arc_scores, rel_scores, heads, rels, mask):
+         """Compute the arc + relation cross-entropy loss."""
+         batch_size, seq_len = mask.shape
+
+         # Arc loss
+         arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
+         arc_loss = F.cross_entropy(
+             arc_scores[mask].view(-1, seq_len),
+             heads[mask],
+             reduction='mean'
+         )
+
+         # Relation loss - select scores at the gold heads
+         rel_scores_gold = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), heads]
+         rel_loss = F.cross_entropy(
+             rel_scores_gold[mask],
+             rels[mask],
+             reduction='mean'
+         )
+
+         return arc_loss + rel_loss
+
+     def decode(self, arc_scores, rel_scores, mask):
+         """Greedy decoding: best head per token, then best relation for that head."""
+         arc_preds = arc_scores.argmax(dim=-1)
+
+         batch_size, seq_len = mask.shape
+         rel_scores_pred = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), arc_preds]
+         rel_preds = rel_scores_pred.argmax(dim=-1)
+
+         return arc_preds, rel_preds
+
+
+ # ============================================================================
+ # Training
+ # ============================================================================
+
+ def evaluate(model, dataloader, device):
+     """Evaluate the model and return UAS/LAS percentages."""
+     model.eval()
+
+     total_arcs = 0
+     correct_arcs = 0
+     correct_rels = 0
+
+     with torch.no_grad():
+         for batch in dataloader:
+             words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
+
+             arc_scores, rel_scores = model(words, chars, mask)
+             arc_preds, rel_preds = model.decode(arc_scores, rel_scores, mask)
+
+             # Count correct arcs and correctly labeled arcs
+             arc_correct = (arc_preds == heads) & mask
+             rel_correct = (rel_preds == rels) & mask & arc_correct
+
+             total_arcs += mask.sum().item()
+             correct_arcs += arc_correct.sum().item()
+             correct_rels += rel_correct.sum().item()
+
+     uas = correct_arcs / total_arcs * 100
+     las = correct_rels / total_arcs * 100
+
+     return uas, las
+
+
+ @click.command()
+ @click.option('--output', '-o', default='models/bamboo-1', help='Output directory')
+ @click.option('--epochs', default=100, type=int, help='Number of epochs')
+ @click.option('--batch-size', default=32, type=int, help='Batch size')
+ @click.option('--lr', default=2e-3, type=float, help='Learning rate')
+ @click.option('--lstm-hidden', default=400, type=int, help='LSTM hidden size')
+ @click.option('--lstm-layers', default=3, type=int, help='LSTM layers')
+ @click.option('--patience', default=10, type=int, help='Early stopping patience')
+ @click.option('--force-download', is_flag=True, help='Force re-download dataset')
+ @click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
+ @click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
+ @click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
+ @click.option('--wandb-project', default='bamboo-1', help='W&B project name')
+ @click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
+ @click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
+ def train(output, epochs, batch_size, lr, lstm_hidden, lstm_layers, patience, force_download, gpu_type, cost_interval, use_wandb, wandb_project, max_time, sample):
+     """Train the Bamboo-1 Vietnamese Dependency Parser."""
+
+     # Detect hardware
+     hardware = detect_hardware()
+     detected_gpu_type = hardware.get_gpu_type()
+
+     # Use the detected GPU type unless the default was explicitly overridden
+     if gpu_type == "RTX_A4000":  # default value
+         gpu_type = detected_gpu_type
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     click.echo(f"Using device: {device}")
+     click.echo(f"Hardware: {hardware}")
+
+     # Initialize wandb
+     if use_wandb:
+         import wandb
+         wandb.init(
+             project=wandb_project,
+             config={
+                 "epochs": epochs,
+                 "batch_size": batch_size,
+                 "lr": lr,
+                 "lstm_hidden": lstm_hidden,
+                 "lstm_layers": lstm_layers,
+                 "patience": patience,
+                 "gpu_type": gpu_type,
+                 "hardware": hardware.to_dict(),
+             }
+         )
+         click.echo(f"W&B logging enabled: {wandb.run.url}")
+
+     click.echo("=" * 60)
+     click.echo("Bamboo-1: Vietnamese Dependency Parser")
+     click.echo("=" * 60)
+
+     # Load corpus
+     click.echo("\nLoading UDD-1 corpus...")
+     corpus = UDD1Corpus(force_download=force_download)
+
+     train_sents = read_conllu(corpus.train)
+     dev_sents = read_conllu(corpus.dev)
+     test_sents = read_conllu(corpus.test)
+
+     # Sample a subset if requested
+     if sample > 0:
+         train_sents = train_sents[:sample]
+         dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
+         test_sents = test_sents[:min(sample // 2, len(test_sents))]
+         click.echo(f"  Sampling {sample} sentences...")
+
+     click.echo(f"  Train: {len(train_sents)} sentences")
+     click.echo(f"  Dev: {len(dev_sents)} sentences")
+     click.echo(f"  Test: {len(test_sents)} sentences")
+
+     # Build vocabulary
+     click.echo("\nBuilding vocabulary...")
+     vocab = Vocabulary(min_freq=2)
+     vocab.build(train_sents)
+     click.echo(f"  Words: {vocab.n_words}")
+     click.echo(f"  Chars: {vocab.n_chars}")
+     click.echo(f"  Relations: {vocab.n_rels}")
+
+     # Create datasets
+     train_dataset = DependencyDataset(train_sents, vocab)
+     dev_dataset = DependencyDataset(dev_sents, vocab)
+     test_dataset = DependencyDataset(test_sents, vocab)
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
+     dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn)
+     test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)
+
+     # Create model
+     click.echo("\nInitializing model...")
+     model = BiaffineDependencyParser(
+         n_words=vocab.n_words,
+         n_chars=vocab.n_chars,
+         n_rels=vocab.n_rels,
+         lstm_hidden=lstm_hidden,
+         lstm_layers=lstm_layers,
+     ).to(device)
+
+     n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     click.echo(f"  Parameters: {n_params:,}")
+
+     # Optimizer and LR schedule
+     optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))
+     scheduler = ExponentialLR(optimizer, gamma=0.75 ** (1 / 5000))
+
+     # Training loop
+     click.echo(f"\nTraining for {epochs} epochs...")
+     if max_time > 0:
+         click.echo(f"Time limit: {max_time} minutes")
+     output_path = Path(output)
+     output_path.mkdir(parents=True, exist_ok=True)
+
+     # Cost tracking
+     cost_tracker = CostTracker(gpu_type=gpu_type)
+     cost_tracker.report_interval = cost_interval
+     cost_tracker.start()
+     click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")
+
+     best_las = -1
+     no_improve = 0
+     time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')
+
+     for epoch in range(1, epochs + 1):
+         # Check the time limit
+         if cost_tracker.elapsed_seconds() >= time_limit_seconds:
+             click.echo(f"\nTime limit reached ({max_time} minutes)")
+             break
+         model.train()
+         total_loss = 0
+
+         pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
+         for batch in pbar:
+             words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
+
+             optimizer.zero_grad()
+             arc_scores, rel_scores = model(words, chars, mask)
+             loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
+             loss.backward()
+             nn.utils.clip_grad_norm_(model.parameters(), 5.0)
+             optimizer.step()
+             scheduler.step()
+
+             total_loss += loss.item()
+             pbar.set_postfix({'loss': f'{loss.item():.4f}'})
+
+         # Evaluate
+         dev_uas, dev_las = evaluate(model, dev_loader, device)
+
+         # Cost update
+         progress = epoch / epochs
+         current_cost = cost_tracker.current_cost()
+         estimated_total_cost = cost_tracker.estimate_total_cost(progress)
+         elapsed_minutes = cost_tracker.elapsed_seconds() / 60
+
+         cost_status = cost_tracker.update(epoch, epochs)
+         if cost_status:
+             click.echo(f"  [{cost_status}]")
+
+         avg_loss = total_loss / len(train_loader)
+         click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
+                    f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}%")
+
+         # Log to wandb
+         if use_wandb:
+             wandb.log({
+                 "epoch": epoch,
+                 "train/loss": avg_loss,
+                 "dev/uas": dev_uas,
+                 "dev/las": dev_las,
+                 "cost/current_usd": current_cost,
+                 "cost/estimated_total_usd": estimated_total_cost,
+                 "cost/elapsed_minutes": elapsed_minutes,
+             })
+
+         # Save the best model
+         if dev_las >= best_las:
+             best_las = dev_las
+             no_improve = 0
+             torch.save({
+                 'model': model.state_dict(),
+                 'vocab': vocab,
+                 'config': {
+                     'n_words': vocab.n_words,
+                     'n_chars': vocab.n_chars,
+                     'n_rels': vocab.n_rels,
+                     'lstm_hidden': lstm_hidden,
+                     'lstm_layers': lstm_layers,
+                 }
+             }, output_path / 'model.pt')
+             click.echo(f"  -> Saved best model (LAS: {best_las:.2f}%)")
+         else:
+             no_improve += 1
+             if no_improve >= patience:
+                 click.echo(f"\nEarly stopping after {patience} epochs without improvement")
+                 break
+
+     # Final evaluation
+     click.echo("\nLoading best model for final evaluation...")
+     checkpoint = torch.load(output_path / 'model.pt', weights_only=False)
+     model.load_state_dict(checkpoint['model'])
+
+     test_uas, test_las = evaluate(model, test_loader, device)
+     click.echo("\nTest Results:")
+     click.echo(f"  UAS: {test_uas:.2f}%")
+     click.echo(f"  LAS: {test_las:.2f}%")
+
+     click.echo(f"\nModel saved to: {output_path}")
+
+     # Final cost summary
+     final_cost = cost_tracker.current_cost()
+     click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")
+
+     # Log final metrics to wandb
+     if use_wandb:
+         wandb.log({
+             "test/uas": test_uas,
+             "test/las": test_las,
+             "cost/final_usd": final_cost,
+         })
+         wandb.finish()
+
+
+ if __name__ == '__main__':
+     train()
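A note on the CoNLL-U conventions that `read_conllu` above relies on: token IDs containing a range (`1-2`) mark multi-word tokens and decimal IDs (`1.1`) mark empty nodes; both are skipped so the kept FORM/HEAD/DEPREL columns stay aligned with surface tokens. A minimal standalone sketch of the same filtering (the three-token sample sentence is invented for illustration):

```python
# Minimal sketch of the CoNLL-U filtering rules assumed by read_conllu.
# Columns per token line: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC.
sample = "\n".join([
    "# text = Tôi yêu Việt_Nam",
    "1-2\tTôi_yêu\t_\t_\t_\t_\t_\t_\t_\t_",      # multi-word token range: skipped
    "1\tTôi\t_\tPRON\t_\t_\t2\tnsubj\t_\t_",
    "2\tyêu\t_\tVERB\t_\t_\t0\troot\t_\t_",
    "3\tViệt_Nam\t_\tPROPN\t_\t_\t2\tobj\t_\t_",
])

def parse_block(text):
    # Same rules as read_conllu: drop comments, blank lines,
    # multi-word token ranges ("1-2"), and empty nodes ("1.1").
    words, heads, rels = [], [], []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        parts = line.split('\t')
        if '-' in parts[0] or '.' in parts[0]:
            continue
        words.append(parts[1])       # FORM
        heads.append(int(parts[6]))  # HEAD (0 = root)
        rels.append(parts[7])        # DEPREL
    return words, heads, rels

words, heads, rels = parse_block(sample)
print(words)  # ['Tôi', 'yêu', 'Việt_Nam']
print(heads)  # [2, 0, 2]
print(rels)   # ['nsubj', 'root', 'obj']
```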
scripts/train_gpu.py ADDED
@@ -0,0 +1,70 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "runpod>=1.6.0",
+ #     "click>=8.0.0",
+ # ]
+ # ///
+ """
+ Simple GPU training script for Bamboo-1 using RunPod.
+
+ Usage:
+     export RUNPOD_API_KEY="your-api-key"
+     uv run scripts/train_gpu.py
+     uv run scripts/train_gpu.py --gpu "NVIDIA RTX 3090"
+     uv run scripts/train_gpu.py --max-epochs 50
+ """
+
+ import os
+ import click
+ import runpod
+
+
+ @click.command()
+ @click.option("--gpu", default="NVIDIA RTX A4000", help="GPU type")
+ @click.option("--max-epochs", default=100, type=int, help="Max training epochs")
+ @click.option("--batch-size", default=32, type=int, help="Sentences per batch")
+ @click.option("--name", default="bamboo-1-train", help="Pod name")
+ def main(gpu, max_epochs, batch_size, name):
+     """Launch Bamboo-1 training on a RunPod GPU."""
+     api_key = os.environ.get("RUNPOD_API_KEY")
+     if not api_key:
+         raise click.ClickException(
+             "Set RUNPOD_API_KEY environment variable.\n"
+             "Get your key at: https://runpod.io/console/user/settings"
+         )
+
+     runpod.api_key = api_key
+
+     # One-liner to avoid string escaping issues; flags match scripts/train.py
+     train_cmd = (
+         f"curl -LsSf https://astral.sh/uv/install.sh | sh && "
+         f"source $HOME/.local/bin/env && "
+         f"git clone https://huggingface.co/undertheseanlp/bamboo-1 && "
+         f"cd bamboo-1 && "
+         f"uv sync && "
+         f"uv run scripts/train.py --output models/bamboo-1 --epochs {max_epochs} --batch-size {batch_size}"
+     )
+
+     click.echo("Launching RunPod training...")
+     click.echo(f"  GPU: {gpu}")
+     click.echo(f"  Epochs: {max_epochs}")
+     click.echo(f"  Batch size: {batch_size}")
+
+     pod = runpod.create_pod(
+         name=name,
+         image_name="runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
+         gpu_type_id=gpu,
+         volume_in_gb=20,
+         docker_args=train_cmd,
+     )
+
+     click.echo("\nPod launched!")
+     click.echo(f"  ID: {pod['id']}")
+     click.echo("  Monitor: https://runpod.io/console/pods")
+     click.echo(f"\nTo stop: uv run scripts/runpod_setup.py terminate {pod['id']}")
+
+
+ if __name__ == "__main__":
+     main()
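The bootstrap pattern above — a single `&&`-chained command handed to the pod via `docker_args`, so that installing uv, cloning the repo, syncing dependencies, and starting training happen in one shot — can be sketched independently of the RunPod SDK. The helper name `build_train_cmd` and the flag values are illustrative:

```python
# Sketch of the single-command pod bootstrap: each step is chained with
# "&&" so a failure at any stage aborts the whole run on the pod.
def build_train_cmd(epochs: int, batch_size: int) -> str:
    steps = [
        "curl -LsSf https://astral.sh/uv/install.sh | sh",
        "source $HOME/.local/bin/env",
        "git clone https://huggingface.co/undertheseanlp/bamboo-1",
        "cd bamboo-1",
        "uv sync",
        f"uv run scripts/train.py --epochs {epochs} --batch-size {batch_size}",
    ]
    return " && ".join(steps)

cmd = build_train_cmd(epochs=50, batch_size=32)
print(cmd)
```

Passing this as one string avoids quoting problems that arise when multi-line scripts are embedded in a container's entrypoint arguments.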
scripts/watch_pod.py ADDED
@@ -0,0 +1,113 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "runpod>=1.6.0",
+ #     "click>=8.0.0",
+ # ]
+ # ///
+ """
+ Watch RunPod pod status.
+
+ Usage:
+     export $(cat .env | xargs) && uv run scripts/watch_pod.py
+     export $(cat .env | xargs) && uv run scripts/watch_pod.py --pod-id <id>
+ """
+
+ import os
+ import time
+ import click
+ import runpod
+ from runpod.api.graphql import run_graphql_query
+
+
+ def get_pod_status(pod_id):
+     query = f'''
+     query getPodStatus {{
+         pod(input: {{ podId: "{pod_id}" }}) {{
+             id
+             name
+             desiredStatus
+             runtime {{
+                 uptimeInSeconds
+                 gpus {{
+                     gpuUtilPercent
+                     memoryUtilPercent
+                 }}
+                 container {{
+                     cpuPercent
+                     memoryPercent
+                 }}
+             }}
+         }}
+     }}
+     '''
+     return run_graphql_query(query)
+
+
+ @click.command()
+ @click.option("--pod-id", default=None, help="Pod ID to watch")
+ @click.option("--interval", default=10, type=int, help="Refresh interval in seconds")
+ def main(pod_id, interval):
+     """Watch RunPod pod status in real time."""
+     api_key = os.environ.get("RUNPOD_API_KEY")
+     if not api_key:
+         raise click.ClickException("Set RUNPOD_API_KEY")
+
+     runpod.api_key = api_key
+
+     # Fall back to the first active pod if no ID was given
+     if not pod_id:
+         pods = runpod.get_pods()
+         if not pods:
+             click.echo("No active pods found.")
+             return
+         pod_id = pods[0]["id"]
+         click.echo(f"Watching pod: {pods[0].get('name', pod_id)}")
+
+     click.echo(f"Refreshing every {interval}s. Press Ctrl+C to stop.\n")
+
+     try:
+         while True:
+             result = get_pod_status(pod_id)
+             pod = result.get("data", {}).get("pod")
+
+             if not pod:
+                 click.echo("Pod not found or terminated.")
+                 break
+
+             # Clear the screen and print the current status
+             click.clear()
+             click.echo(f"=== {pod['name']} ({pod['id']}) ===")
+             click.echo(f"Status: {pod['desiredStatus']}")
+
+             runtime = pod.get("runtime") or {}
+             uptime = runtime.get("uptimeInSeconds", 0)
+             mins, secs = divmod(uptime, 60)
+             hours, mins = divmod(mins, 60)
+             click.echo(f"Uptime: {int(hours)}h {int(mins)}m {int(secs)}s")
+
+             gpus = runtime.get("gpus") or []
+             if gpus:
+                 gpu = gpus[0]
+                 click.echo(f"GPU Util: {gpu.get('gpuUtilPercent', 0):.1f}%")
+                 click.echo(f"GPU Mem: {gpu.get('memoryUtilPercent', 0):.1f}%")
+
+             container = runtime.get("container") or {}
+             click.echo(f"CPU: {container.get('cpuPercent', 0):.1f}%")
+             click.echo(f"Memory: {container.get('memoryPercent', 0):.1f}%")
+
+             click.echo(f"\nLast update: {time.strftime('%H:%M:%S')}")
+             click.echo("Press Ctrl+C to stop")
+
+             if pod["desiredStatus"] not in ["RUNNING", "STARTING"]:
+                 click.echo(f"\nPod is {pod['desiredStatus']}. Stopping watch.")
+                 break
+
+             time.sleep(interval)
+
+     except KeyboardInterrupt:
+         click.echo("\nStopped watching.")
+
+
+ if __name__ == "__main__":
+     main()
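The uptime display in the watch loop converts raw seconds into `Xh Ym Zs` with two chained `divmod` calls; isolated as a small helper (the function name is illustrative):

```python
# Standalone version of the uptime formatting used in the watch loop:
# the first divmod splits seconds into minutes+seconds, the second
# splits those minutes into hours+minutes.
def format_uptime(uptime_seconds: int) -> str:
    mins, secs = divmod(uptime_seconds, 60)
    hours, mins = divmod(mins, 60)
    return f"{int(hours)}h {int(mins)}m {int(secs)}s"

print(format_uptime(3725))  # 1h 2m 5s
print(format_uptime(59))    # 0h 0m 59s
```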
uv.lock ADDED
The diff for this file is too large to render.