Mindigenous commited on
Commit ·
5ae3e12
1
Parent(s): 7a24ed3
Sync latest workspace state: data/scripts updates and archive cleanup
Browse files- AMD Cloud Logs.txt +8 -1
- apps_converter.py +127 -0
- backup_step8500.tar.gz +0 -3
- backup_step8750.tar.gz +0 -3
- backup_step9000.tar.gz +0 -3
- backup_step9250.tar.gz +0 -3
- backup_step9500.tar.gz +0 -3
- backup_step9750.tar.gz +0 -3
- backups/backup_step4250.tar.gz +0 -3
- backups/backup_step4500.tar.gz +0 -3
- backups/backup_step4750.tar.gz +0 -3
- backups/backup_step5000.tar.gz +0 -3
- backups/backup_step5250.tar.gz +0 -3
- backups/backup_step5500.tar.gz +0 -3
- backups/backup_step5750.tar.gz +0 -3
- backups/backup_step6000.tar.gz +0 -3
- backups/backup_step6250.tar.gz +0 -3
- backups/backup_step6500.tar.gz +0 -3
- backups/backup_step6750.tar.gz +0 -3
- backups/backup_step7000.tar.gz +0 -3
- backups/backup_step7250.tar.gz +0 -3
- backups/backup_step7500.tar.gz +0 -3
- backups/backup_step7750.tar.gz +0 -3
- checkpoints/component5_420m/latest.pt +0 -3
- checkpoints/component5_420m/step_3000.pt +0 -3
- checkpoints/component5_420m/step_3200.pt +0 -3
- codeforces_ingest.py +288 -0
- backup_step1000.tar.gz → data/final/_rebalance_tmp/instruction.jsonl +2 -2
- backup_step2000.tar.gz → data/final/_rebalance_tmp/problem.jsonl +2 -2
- backup_step8000.tar.gz → data/final/_rebalance_tmp/rebalance_seen.sqlite +2 -2
- backup_step3000.tar.gz → data/final/_rebalance_tmp/structured.jsonl +2 -2
- backup_step8250.tar.gz → data/final/dedupe_hashes.sqlite +2 -2
- data/final/train.jsonl +3 -0
- data/raw/custom_finetune_pairs.jsonl +0 -3
- data_fetch.py +1025 -170
- dataset_cleaner.py +342 -0
- dataset_formatter.py +102 -0
- final_model/config.json +29 -0
- final_model/configuration_mindi.py +38 -0
- final_model/generation_config.json +10 -0
- backup_step4000.tar.gz → final_model/model.safetensors +2 -2
- final_model/modeling_mindi.py +219 -0
- final_model/tokenization_mindi.py +33 -0
- final_model/tokenizer.json +799 -0
- final_model/tokenizer_config.json +191 -0
- logs/data_fetch.log +2 -2
- merge.py +29 -0
- requirements.txt +1 -0
- test.py +35 -0
AMD Cloud Logs.txt
CHANGED
|
@@ -408,4 +408,11 @@ trainable params: 7,630,848 || all params: 431,565,696 || trainable%: 1.7682
|
|
| 408 |
{'loss': 6.2728, 'grad_norm': 4.490257740020752, 'learning_rate': 7.265401482791907e-06, 'epoch': 1.85}
|
| 409 |
{'loss': 6.4827, 'grad_norm': 4.102600574493408, 'learning_rate': 7.250086658147697e-06, 'epoch': 1.85}
|
| 410 |
{'loss': 6.2786, 'grad_norm': 4.251227378845215, 'learning_rate': 7.23474531713807e-06, 'epoch': 1.86}
|
| 411 |
-
37%|█████████████████████████████████████████████████████████████████▉ | 7028/18870 [44:42<2:03:45, 1.59it/s]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
{'loss': 6.2728, 'grad_norm': 4.490257740020752, 'learning_rate': 7.265401482791907e-06, 'epoch': 1.85}
|
| 409 |
{'loss': 6.4827, 'grad_norm': 4.102600574493408, 'learning_rate': 7.250086658147697e-06, 'epoch': 1.85}
|
| 410 |
{'loss': 6.2786, 'grad_norm': 4.251227378845215, 'learning_rate': 7.23474531713807e-06, 'epoch': 1.86}
|
| 411 |
+
37%|█████████████████████████████████████████████████████████████████▉ | 7028/18870 [44:42<2:03:45, 1.59it/s]
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
apt update && apt install -y git-lfs
|
| 416 |
+
git clone https://huggingface.co/Mindigenous/mindi-backup
|
| 417 |
+
cd mindi-backup
|
| 418 |
+
tar -xzvf backup_step12000.tar.gz
|
apps_converter.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Iterable, List, Tuple
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
PROJECT_ROOT = Path(__file__).resolve().parent
|
| 10 |
+
INPUT_FILES = [
|
| 11 |
+
PROJECT_ROOT / "apps" / "train.jsonl",
|
| 12 |
+
PROJECT_ROOT / "apps" / "test.jsonl",
|
| 13 |
+
]
|
| 14 |
+
OUTPUT_FILE = PROJECT_ROOT / "data" / "raw" / "apps.jsonl"
|
| 15 |
+
|
| 16 |
+
MAX_SOLUTIONS_PER_PROBLEM = 2
|
| 17 |
+
MIN_RESPONSE_CHARS = 20
|
| 18 |
+
MAX_RESPONSE_TOKENS = 3000
|
| 19 |
+
CODE_HINT_RE = re.compile(
|
| 20 |
+
r"(\bdef\s+\w+\s*\(|\bclass\s+\w+|\bfor\s+\w+\s+in\b|\bwhile\b|[{;}]|\breturn\b|\bimport\b)",
|
| 21 |
+
re.IGNORECASE,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _normalize_text(value: str) -> str:
|
| 26 |
+
return value.strip()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _parse_solutions(raw_solutions) -> List[str]:
|
| 30 |
+
if raw_solutions is None:
|
| 31 |
+
return []
|
| 32 |
+
if isinstance(raw_solutions, list):
|
| 33 |
+
return [str(x) for x in raw_solutions if x is not None]
|
| 34 |
+
if isinstance(raw_solutions, str):
|
| 35 |
+
raw_solutions = raw_solutions.strip()
|
| 36 |
+
if not raw_solutions:
|
| 37 |
+
return []
|
| 38 |
+
try:
|
| 39 |
+
parsed = json.loads(raw_solutions)
|
| 40 |
+
if isinstance(parsed, list):
|
| 41 |
+
return [str(x) for x in parsed if x is not None]
|
| 42 |
+
if isinstance(parsed, str):
|
| 43 |
+
return [parsed]
|
| 44 |
+
return []
|
| 45 |
+
except json.JSONDecodeError:
|
| 46 |
+
return [raw_solutions]
|
| 47 |
+
return []
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _is_code_like(text: str) -> bool:
|
| 51 |
+
return bool(CODE_HINT_RE.search(text))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _iter_jsonl(path: Path) -> Iterable[dict]:
|
| 55 |
+
with path.open("r", encoding="utf-8", errors="ignore") as f:
|
| 56 |
+
for line in f:
|
| 57 |
+
line = line.strip()
|
| 58 |
+
if not line:
|
| 59 |
+
continue
|
| 60 |
+
try:
|
| 61 |
+
obj = json.loads(line)
|
| 62 |
+
except json.JSONDecodeError:
|
| 63 |
+
continue
|
| 64 |
+
if isinstance(obj, dict):
|
| 65 |
+
yield obj
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def convert_apps_dataset(input_files: List[Path], output_file: Path) -> Tuple[int, int, int]:
|
| 69 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
| 70 |
+
|
| 71 |
+
total_input_samples = 0
|
| 72 |
+
valid_output_samples = 0
|
| 73 |
+
skipped_samples = 0
|
| 74 |
+
|
| 75 |
+
with output_file.open("w", encoding="utf-8") as out_f:
|
| 76 |
+
for input_path in input_files:
|
| 77 |
+
if not input_path.exists():
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
for item in tqdm(_iter_jsonl(input_path), desc=f"apps:{input_path.name}", unit="rows"):
|
| 81 |
+
total_input_samples += 1
|
| 82 |
+
|
| 83 |
+
question = _normalize_text(str(item.get("question", "")))
|
| 84 |
+
if not question:
|
| 85 |
+
skipped_samples += 1
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
all_solutions = _parse_solutions(item.get("solutions"))
|
| 89 |
+
if not all_solutions:
|
| 90 |
+
skipped_samples += 1
|
| 91 |
+
continue
|
| 92 |
+
|
| 93 |
+
usable = 0
|
| 94 |
+
for raw_solution in all_solutions:
|
| 95 |
+
solution = _normalize_text(raw_solution)
|
| 96 |
+
if not solution:
|
| 97 |
+
continue
|
| 98 |
+
if len(solution) < MIN_RESPONSE_CHARS:
|
| 99 |
+
continue
|
| 100 |
+
if len(solution.split()) > MAX_RESPONSE_TOKENS:
|
| 101 |
+
continue
|
| 102 |
+
if not _is_code_like(solution):
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
row = {
|
| 106 |
+
"instruction": f"Solve the following problem:\n{question}",
|
| 107 |
+
"response": solution,
|
| 108 |
+
}
|
| 109 |
+
out_f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 110 |
+
valid_output_samples += 1
|
| 111 |
+
usable += 1
|
| 112 |
+
if usable >= MAX_SOLUTIONS_PER_PROBLEM:
|
| 113 |
+
break
|
| 114 |
+
|
| 115 |
+
if usable == 0:
|
| 116 |
+
skipped_samples += 1
|
| 117 |
+
|
| 118 |
+
return total_input_samples, valid_output_samples, skipped_samples
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
total_input, valid_output, skipped = convert_apps_dataset(INPUT_FILES, OUTPUT_FILE)
|
| 123 |
+
print(f"Output: {OUTPUT_FILE}")
|
| 124 |
+
print(f"Total input samples: {total_input}")
|
| 125 |
+
print(f"Valid output samples: {valid_output}")
|
| 126 |
+
print(f"Skipped samples: {skipped}")
|
| 127 |
+
print("APPS dataset ready for training pipeline")
|
backup_step8500.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:d35f24763676911a2f605ff63e56a62f521bde805757d51b2e356a004d479e2e
|
| 3 |
-
size 84695943
|
|
|
|
|
|
|
|
|
|
|
|
backup_step8750.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:007068138f8a165ff5a3fea9ed096a94bdf620d0007b013d8834d69bfc650628
|
| 3 |
-
size 84696682
|
|
|
|
|
|
|
|
|
|
|
|
backup_step9000.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a69b305a69b77ea66f9feeaaaa3bbd7c4a08f7111bbd6cdd3b90e2e59a5b2e7b
|
| 3 |
-
size 84704097
|
|
|
|
|
|
|
|
|
|
|
|
backup_step9250.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f8724eceedfd4f8c4f87a14f1fa8c2019bcbfe9af6165e57aac020bb04c65fd5
|
| 3 |
-
size 84699876
|
|
|
|
|
|
|
|
|
|
|
|
backup_step9500.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a728fcf9931e37ae37a3db4044170a254473aa08f9a10e958ce88987f2575d8c
|
| 3 |
-
size 84705286
|
|
|
|
|
|
|
|
|
|
|
|
backup_step9750.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:fe8bbef08bb3ee21de186753bce613d4b050b4011d85378737d464e190db65a7
|
| 3 |
-
size 84703357
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step4250.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f95bd63ebd8df351b262311c6400b60263f5e975f3022810c432db7207c3d92c
|
| 3 |
-
size 84560529
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step4500.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:b2494fb51792518a027ae7fa403139cb96ada8218c3a1dc85b131850cbf98ed0
|
| 3 |
-
size 84567347
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step4750.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:1971eeabb6531dbbe399c5173e830262e6c0c8708e020ba52f5090e1370a01e8
|
| 3 |
-
size 84587608
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step5000.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:2100ab540c449907be5a9403489969e63c6e25faf7d6d62d81de95977091634a
|
| 3 |
-
size 84605420
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step5250.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:e556e3898a018791ec7afeda40574ea0b192d463c9a306e4d2d30384234f1d5d
|
| 3 |
-
size 84614578
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step5500.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:07e6f0821e8a1ef8568d7733beefc43aa01a88c49f60344c0d0d508ea60c8776
|
| 3 |
-
size 84617900
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step5750.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:6fd42608ddcc28d31c9144fb60c949dcc16cfc1fd90b99bd4a9dc2d059354318
|
| 3 |
-
size 84628951
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step6000.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8643c889c56e07ff862cf9300c33b14fd9eb0d8d130dd3c837ce99bef0c5accb
|
| 3 |
-
size 84638746
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step6250.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:6745b7c92df7b4a915e3a62064c74a537e85f4cac701c98eb798b9e4039db4e8
|
| 3 |
-
size 84646702
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step6500.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:329dbcfd21ab74c713da3b8a6e4dc35bde1d863200b99cdcd183a6c499a18d3c
|
| 3 |
-
size 84646328
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step6750.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:d3e9ac1b4dd644738bdcc09bdd559bafe9e8fb2910dac242a78516b38919af34
|
| 3 |
-
size 84660003
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step7000.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:ad80a05d74d8ca176a7620613a10eca11332d176d5cd8be89b284a521b5409f9
|
| 3 |
-
size 84664111
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step7250.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:282b3f77f6a892bd8813fd280f0bfd5ebb80545aca2f5614b0f3aa85358fb46a
|
| 3 |
-
size 84667605
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step7500.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:3e523cd3c7479b88b40f84b7f46c75b749f712d61e1223d068f25ea330c428ed
|
| 3 |
-
size 84668063
|
|
|
|
|
|
|
|
|
|
|
|
backups/backup_step7750.tar.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:775b3f4094fbf31da00a81b1f76f84efed709a17c5da27c5b2d9c8b8cb11389b
|
| 3 |
-
size 84688475
|
|
|
|
|
|
|
|
|
|
|
|
checkpoints/component5_420m/latest.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:32d26a7dd9e6e294c6657f6fb3a4d947cf52eb8e1c0b11032722fa50d15c4a21
|
| 3 |
-
size 5087449970
|
|
|
|
|
|
|
|
|
|
|
|
checkpoints/component5_420m/step_3000.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:e11bded40789574ef316636c02c2fd1e8cd54c13441d8cd6a28980f2209ffaa9
|
| 3 |
-
size 5087455158
|
|
|
|
|
|
|
|
|
|
|
|
checkpoints/component5_420m/step_3200.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:71d2ea9401f3b08b2528dbb8f993949794d0adb57642d0f4752d74da0e445238
|
| 3 |
-
size 5087455158
|
|
|
|
|
|
|
|
|
|
|
|
codeforces_ingest.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import csv
|
| 3 |
+
import html
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import sqlite3
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Iterable, Iterator, List, Tuple
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
HTML_TAG_RE = re.compile(r"<[^>]+>")
|
| 15 |
+
WS_RE = re.compile(r"\s+")
|
| 16 |
+
CODE_HINT_RE = re.compile(
|
| 17 |
+
r"(\bdef\b|\bclass\b|#include|public\s+class|function\s+\w+|\breturn\b|\bimport\b)",
|
| 18 |
+
re.IGNORECASE,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def setup_logger(log_path: Path) -> logging.Logger:
|
| 23 |
+
log_path.parent.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
logger = logging.getLogger("codeforces_ingest")
|
| 25 |
+
logger.setLevel(logging.INFO)
|
| 26 |
+
if logger.handlers:
|
| 27 |
+
return logger
|
| 28 |
+
formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
|
| 29 |
+
fh = logging.FileHandler(log_path, encoding="utf-8")
|
| 30 |
+
fh.setFormatter(formatter)
|
| 31 |
+
sh = logging.StreamHandler()
|
| 32 |
+
sh.setFormatter(formatter)
|
| 33 |
+
logger.addHandler(fh)
|
| 34 |
+
logger.addHandler(sh)
|
| 35 |
+
return logger
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def clean_text(text: str) -> str:
|
| 39 |
+
if not text:
|
| 40 |
+
return ""
|
| 41 |
+
text = html.unescape(str(text))
|
| 42 |
+
text = HTML_TAG_RE.sub(" ", text)
|
| 43 |
+
text = text.replace("\r\n", "\n").replace("\r", "\n")
|
| 44 |
+
text = "\n".join(line.strip() for line in text.split("\n"))
|
| 45 |
+
text = WS_RE.sub(" ", text)
|
| 46 |
+
return text.strip()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _safe_get(record: Dict[str, object], keys: Iterable[str]) -> str:
|
| 50 |
+
for key in keys:
|
| 51 |
+
val = record.get(key)
|
| 52 |
+
if val:
|
| 53 |
+
return str(val)
|
| 54 |
+
return ""
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _extract_pair(record: Dict[str, object]) -> Tuple[str, str]:
|
| 58 |
+
problem = _safe_get(
|
| 59 |
+
record,
|
| 60 |
+
[
|
| 61 |
+
"problem_statement",
|
| 62 |
+
"statement",
|
| 63 |
+
"problem",
|
| 64 |
+
"question",
|
| 65 |
+
"content",
|
| 66 |
+
"description",
|
| 67 |
+
"prompt",
|
| 68 |
+
"instruction",
|
| 69 |
+
],
|
| 70 |
+
)
|
| 71 |
+
solution = _safe_get(
|
| 72 |
+
record,
|
| 73 |
+
[
|
| 74 |
+
"solution",
|
| 75 |
+
"solution_code",
|
| 76 |
+
"answer",
|
| 77 |
+
"code",
|
| 78 |
+
"response",
|
| 79 |
+
"python",
|
| 80 |
+
"cpp",
|
| 81 |
+
"java",
|
| 82 |
+
"javascript",
|
| 83 |
+
],
|
| 84 |
+
)
|
| 85 |
+
return clean_text(problem), clean_text(solution)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _iter_jsonl(path: Path) -> Iterator[Dict[str, object]]:
|
| 89 |
+
with path.open("r", encoding="utf-8", errors="ignore") as f:
|
| 90 |
+
for line in f:
|
| 91 |
+
line = line.strip()
|
| 92 |
+
if not line:
|
| 93 |
+
continue
|
| 94 |
+
try:
|
| 95 |
+
obj = json.loads(line)
|
| 96 |
+
if isinstance(obj, dict):
|
| 97 |
+
yield obj
|
| 98 |
+
except Exception:
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _iter_json_stream(path: Path, logger: logging.Logger) -> Iterator[Dict[str, object]]:
|
| 103 |
+
# Streaming JSON parser for array/object JSON files if ijson is available.
|
| 104 |
+
try:
|
| 105 |
+
import ijson # type: ignore
|
| 106 |
+
except Exception:
|
| 107 |
+
logger.warning("Skipping JSON file without ijson installed (to preserve streaming): %s", path)
|
| 108 |
+
return iter(())
|
| 109 |
+
|
| 110 |
+
def gen():
|
| 111 |
+
with path.open("rb") as f:
|
| 112 |
+
try:
|
| 113 |
+
for obj in ijson.items(f, "item"):
|
| 114 |
+
if isinstance(obj, dict):
|
| 115 |
+
yield obj
|
| 116 |
+
except Exception:
|
| 117 |
+
# Some files may be top-level dicts with nested lists.
|
| 118 |
+
f.seek(0)
|
| 119 |
+
try:
|
| 120 |
+
root = next(ijson.items(f, ""))
|
| 121 |
+
if isinstance(root, dict):
|
| 122 |
+
for v in root.values():
|
| 123 |
+
if isinstance(v, list):
|
| 124 |
+
for obj in v:
|
| 125 |
+
if isinstance(obj, dict):
|
| 126 |
+
yield obj
|
| 127 |
+
except Exception:
|
| 128 |
+
return
|
| 129 |
+
|
| 130 |
+
return gen()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _iter_csv_like(path: Path, delimiter: str) -> Iterator[Dict[str, object]]:
|
| 134 |
+
with path.open("r", encoding="utf-8", errors="ignore", newline="") as f:
|
| 135 |
+
reader = csv.DictReader(f, delimiter=delimiter)
|
| 136 |
+
for row in reader:
|
| 137 |
+
if isinstance(row, dict):
|
| 138 |
+
yield row
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _iter_txt_records(path: Path) -> Iterator[Dict[str, object]]:
|
| 142 |
+
# Heuristic fallback for text dumps:
|
| 143 |
+
# split on obvious separators and map to pseudo records.
|
| 144 |
+
with path.open("r", encoding="utf-8", errors="ignore") as f:
|
| 145 |
+
blob = f.read()
|
| 146 |
+
chunks = re.split(r"\n\s*[-=]{3,}\s*\n|\n\s*Problem\s+\d+\s*\n", blob, flags=re.IGNORECASE)
|
| 147 |
+
for chunk in chunks:
|
| 148 |
+
chunk = chunk.strip()
|
| 149 |
+
if len(chunk) < 120:
|
| 150 |
+
continue
|
| 151 |
+
parts = re.split(r"\n\s*(Solution|Answer)\s*:\s*\n", chunk, flags=re.IGNORECASE)
|
| 152 |
+
if len(parts) < 3:
|
| 153 |
+
continue
|
| 154 |
+
yield {"problem_statement": parts[0], "solution": parts[2]}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _iter_candidate_files(input_dir: Path) -> Iterator[Path]:
|
| 158 |
+
patterns = [
|
| 159 |
+
"**/*.jsonl",
|
| 160 |
+
"**/*.json",
|
| 161 |
+
"**/*.csv",
|
| 162 |
+
"**/*.tsv",
|
| 163 |
+
"**/*.txt",
|
| 164 |
+
]
|
| 165 |
+
seen = set()
|
| 166 |
+
for pat in patterns:
|
| 167 |
+
for path in input_dir.glob(pat):
|
| 168 |
+
if ".git" in path.parts:
|
| 169 |
+
continue
|
| 170 |
+
lower = str(path).lower()
|
| 171 |
+
if "codeforces" not in lower:
|
| 172 |
+
continue
|
| 173 |
+
if path.name.lower() == "codeforces.jsonl":
|
| 174 |
+
continue
|
| 175 |
+
if path.is_file() and path not in seen:
|
| 176 |
+
seen.add(path)
|
| 177 |
+
yield path
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def ingest_codeforces(input_dir: Path, output_file: Path, logger: logging.Logger) -> Dict[str, int]:
|
| 181 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
| 182 |
+
extracted = 0
|
| 183 |
+
filtered = 0
|
| 184 |
+
bad = 0
|
| 185 |
+
usable = 0
|
| 186 |
+
deduped = 0
|
| 187 |
+
|
| 188 |
+
files = list(_iter_candidate_files(input_dir))
|
| 189 |
+
if not files:
|
| 190 |
+
logger.warning("No Codeforces raw files found under %s", input_dir.resolve())
|
| 191 |
+
output_file.write_text("", encoding="utf-8")
|
| 192 |
+
return {"extracted": 0, "filtered": 0, "bad": 0, "usable": 0}
|
| 193 |
+
|
| 194 |
+
dedupe_db = output_file.parent / "codeforces_ingest_dedupe.sqlite"
|
| 195 |
+
if dedupe_db.exists():
|
| 196 |
+
dedupe_db.unlink()
|
| 197 |
+
for suffix in ("-wal", "-shm"):
|
| 198 |
+
s = dedupe_db.with_name(dedupe_db.name + suffix)
|
| 199 |
+
if s.exists():
|
| 200 |
+
s.unlink()
|
| 201 |
+
|
| 202 |
+
conn = sqlite3.connect(str(dedupe_db))
|
| 203 |
+
conn.execute("PRAGMA journal_mode=WAL;")
|
| 204 |
+
conn.execute("CREATE TABLE IF NOT EXISTS seen_hashes (h TEXT PRIMARY KEY)")
|
| 205 |
+
|
| 206 |
+
def is_dup(instruction: str, response: str) -> bool:
|
| 207 |
+
import hashlib
|
| 208 |
+
|
| 209 |
+
h = hashlib.sha256(f"{instruction}||{response}".encode("utf-8")).hexdigest()
|
| 210 |
+
try:
|
| 211 |
+
conn.execute("INSERT INTO seen_hashes(h) VALUES (?)", (h,))
|
| 212 |
+
return False
|
| 213 |
+
except sqlite3.IntegrityError:
|
| 214 |
+
return True
|
| 215 |
+
|
| 216 |
+
with output_file.open("w", encoding="utf-8") as out_f:
|
| 217 |
+
for file_path in tqdm(files, desc="codeforces_files", unit="file"):
|
| 218 |
+
suffix = file_path.suffix.lower()
|
| 219 |
+
if suffix == ".jsonl":
|
| 220 |
+
rec_iter = _iter_jsonl(file_path)
|
| 221 |
+
elif suffix == ".json":
|
| 222 |
+
rec_iter = _iter_json_stream(file_path, logger)
|
| 223 |
+
elif suffix == ".csv":
|
| 224 |
+
rec_iter = _iter_csv_like(file_path, ",")
|
| 225 |
+
elif suffix == ".tsv":
|
| 226 |
+
rec_iter = _iter_csv_like(file_path, "\t")
|
| 227 |
+
else:
|
| 228 |
+
rec_iter = _iter_txt_records(file_path)
|
| 229 |
+
|
| 230 |
+
for rec in tqdm(rec_iter, desc=f"ingest:{file_path.name}", unit="rows", leave=False):
|
| 231 |
+
try:
|
| 232 |
+
extracted += 1
|
| 233 |
+
problem, solution = _extract_pair(rec)
|
| 234 |
+
if len(problem) <= 50 or len(solution) <= 20:
|
| 235 |
+
filtered += 1
|
| 236 |
+
continue
|
| 237 |
+
# Keep response as solution code; reject obvious non-code text.
|
| 238 |
+
if not CODE_HINT_RE.search(solution):
|
| 239 |
+
filtered += 1
|
| 240 |
+
continue
|
| 241 |
+
instruction = f"Solve the following problem:\n{problem}"
|
| 242 |
+
if is_dup(instruction, solution):
|
| 243 |
+
deduped += 1
|
| 244 |
+
continue
|
| 245 |
+
row = {"instruction": instruction, "response": solution}
|
| 246 |
+
out_f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 247 |
+
usable += 1
|
| 248 |
+
except Exception:
|
| 249 |
+
bad += 1
|
| 250 |
+
continue
|
| 251 |
+
conn.commit()
|
| 252 |
+
conn.close()
|
| 253 |
+
|
| 254 |
+
skipped = filtered + bad + deduped
|
| 255 |
+
logger.info("Codeforces ingest total_input=%d", extracted)
|
| 256 |
+
logger.info("Codeforces ingest valid_output=%d", usable)
|
| 257 |
+
logger.info("Codeforces ingest skipped=%d", skipped)
|
| 258 |
+
logger.info("Codeforces ingest filtered=%d", filtered)
|
| 259 |
+
logger.info("Codeforces ingest deduped=%d", deduped)
|
| 260 |
+
logger.info("Codeforces ingest bad=%d", bad)
|
| 261 |
+
return {
|
| 262 |
+
"total_input": extracted,
|
| 263 |
+
"valid_output": usable,
|
| 264 |
+
"skipped": skipped,
|
| 265 |
+
"filtered": filtered,
|
| 266 |
+
"deduped": deduped,
|
| 267 |
+
"bad": bad,
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _build_parser() -> argparse.ArgumentParser:
|
| 272 |
+
parser = argparse.ArgumentParser(
|
| 273 |
+
description="Ingest Codeforces problem-solution data into JSONL for MINDI pipeline."
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument("--input-dir", type=Path, default=Path("./data/raw"))
|
| 276 |
+
parser.add_argument("--output", type=Path, default=Path("./data/raw/codeforces.jsonl"))
|
| 277 |
+
parser.add_argument("--log-file", type=Path, default=Path("./logs/codeforces_ingest.log"))
|
| 278 |
+
return parser
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
if __name__ == "__main__":
|
| 282 |
+
args = _build_parser().parse_args()
|
| 283 |
+
log = setup_logger(args.log_file)
|
| 284 |
+
stats = ingest_codeforces(args.input_dir, args.output, log)
|
| 285 |
+
print(f"Output: {args.output.resolve()}")
|
| 286 |
+
print(f"Total input: {stats['total_input']}")
|
| 287 |
+
print(f"Valid output: {stats['valid_output']}")
|
| 288 |
+
print(f"Skipped: {stats['skipped']}")
|
backup_step1000.tar.gz → data/final/_rebalance_tmp/instruction.jsonl
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77fd726ec9b3b9135edc4b22c251c760ea444060507315d09e0156a9ad08cff2
|
| 3 |
+
size 359523113
|
backup_step2000.tar.gz → data/final/_rebalance_tmp/problem.jsonl
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dea7bb001629f0e7b72363ea4ffc8da89b21cab0322c80faea3bd4352d2d28cd
|
| 3 |
+
size 392283637
|
backup_step8000.tar.gz → data/final/_rebalance_tmp/rebalance_seen.sqlite
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:27065d4515ff3d381d784058a61a997517d89621147f71316b4517997a7567a0
|
| 3 |
+
size 85028864
|
backup_step3000.tar.gz → data/final/_rebalance_tmp/structured.jsonl
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:83898474a15e20eba8f8fd4b4163c381d69b87374c4d82b3e289898da0b9f2fc
|
| 3 |
+
size 283810066
|
backup_step8250.tar.gz → data/final/dedupe_hashes.sqlite
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89622c37e7ff270b6745693d3dbd63a25a7937bf7e5261a01d4676568207d7ea
|
| 3 |
+
size 84996096
|
data/final/train.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f78cdcc0171535ad8ac533beccaec3aab78ef94870055bdf7dd5c798d629aef
|
| 3 |
+
size 1867627125
|
data/raw/custom_finetune_pairs.jsonl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:7ab1ceab4d5a85de0c15a54f6420c483e78de5db4b5654dc5d34aa1d02893921
|
| 3 |
-
size 451
|
|
|
|
|
|
|
|
|
|
|
|
data_fetch.py
CHANGED
|
@@ -1,222 +1,1077 @@
|
|
| 1 |
import argparse
|
| 2 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
-
from typing import Dict, List, Optional
|
| 5 |
|
| 6 |
-
from datasets import
|
|
|
|
| 7 |
|
| 8 |
-
from
|
| 9 |
-
from
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
return ""
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
-
def
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
if not instruction or not output:
|
| 23 |
-
return False
|
| 24 |
-
if len(output) < DATA_CONFIG.min_output_chars:
|
| 25 |
-
return False
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
-
def
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
-
def
|
| 45 |
-
if
|
| 46 |
return
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
def _load_or_download(dataset_name: str, cache_path: Path, **kwargs):
|
| 52 |
-
if cache_path.exists():
|
| 53 |
-
return load_from_disk(str(cache_path))
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
-
def
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
rows: List[Dict[str, str]] = []
|
| 71 |
-
split = ds_obj["test"] if isinstance(ds_obj, DatasetDict) else ds_obj
|
| 72 |
-
|
| 73 |
-
for item in split:
|
| 74 |
-
prompt = item.get("prompt", "")
|
| 75 |
-
solution = item.get("canonical_solution", "")
|
| 76 |
-
if "def " not in prompt:
|
| 77 |
-
continue
|
| 78 |
-
rows.append(
|
| 79 |
-
_to_record(
|
| 80 |
-
instruction="Complete the Python function so it satisfies the specification.",
|
| 81 |
-
input_text=prompt,
|
| 82 |
-
output_text=solution,
|
| 83 |
-
)
|
| 84 |
)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
-
def
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
continue
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
)
|
| 114 |
-
if len(rows) >= max_samples:
|
| 115 |
-
return rows
|
| 116 |
-
return rows
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def _extract_codesearchnet(ds_obj, max_samples: int) -> List[Dict[str, str]]:
|
| 120 |
-
rows: List[Dict[str, str]] = []
|
| 121 |
-
splits = []
|
| 122 |
-
if isinstance(ds_obj, DatasetDict):
|
| 123 |
-
for split_name in ("train", "validation"):
|
| 124 |
-
if split_name in ds_obj:
|
| 125 |
-
splits.append(ds_obj[split_name])
|
| 126 |
-
else:
|
| 127 |
-
splits = [ds_obj]
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
continue
|
| 139 |
-
|
|
|
|
| 140 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
)
|
| 149 |
-
if len(rows) >= max_samples:
|
| 150 |
-
return rows
|
| 151 |
-
return rows
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def _dedupe_and_filter(rows: List[Dict[str, str]], max_total: int) -> List[Dict[str, str]]:
|
| 155 |
-
seen = set()
|
| 156 |
-
clean_rows: List[Dict[str, str]] = []
|
| 157 |
-
for row in rows:
|
| 158 |
-
if not _quality_ok(row):
|
| 159 |
-
continue
|
| 160 |
-
digest = hashlib.sha256(
|
| 161 |
-
f"{row['instruction']}||{row['input']}||{row['output']}".encode("utf-8")
|
| 162 |
-
).hexdigest()
|
| 163 |
-
if digest in seen:
|
| 164 |
-
continue
|
| 165 |
-
seen.add(digest)
|
| 166 |
-
clean_rows.append(row)
|
| 167 |
-
if len(clean_rows) >= max_total:
|
| 168 |
-
break
|
| 169 |
-
return clean_rows
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
-
|
| 173 |
-
ensure_dirs([PATHS.data_dir, PATHS.dataset_cache_dir, PATHS.raw_dataset_dir, PATHS.logs_dir])
|
| 174 |
-
logger = setup_logger("data_fetch", PATHS.logs_dir / "data_fetch.log")
|
| 175 |
|
| 176 |
-
logger.info("Loading datasets (offline_only=%s).", offline_only)
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
)
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
else:
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
rows.extend(_extract_codesearchnet(csn_ds, DATA_CONFIG.max_codesearchnet_samples))
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
-
logger.info("Saved %d cleaned training rows to %s", len(clean_rows), PATHS.train_jsonl)
|
| 205 |
-
print(f"Saved dataset: {PATHS.train_jsonl.resolve()}")
|
| 206 |
-
return PATHS.train_jsonl
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
|
| 210 |
-
parser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
parser.add_argument(
|
| 212 |
-
"--
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
)
|
| 216 |
return parser
|
| 217 |
|
| 218 |
|
| 219 |
if __name__ == "__main__":
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import sqlite3
|
| 6 |
+
from collections import Counter
|
| 7 |
from pathlib import Path
|
| 8 |
+
from typing import Dict, Iterable, List, Optional
|
| 9 |
|
| 10 |
+
from datasets import load_dataset, load_from_disk
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
|
| 13 |
+
from dataset_cleaner import build_balanced_dataset, clean_record
|
| 14 |
+
from dataset_formatter import build_instruction_sample
|
| 15 |
+
from utils import ensure_dirs, setup_logger
|
| 16 |
|
| 17 |
|
| 18 |
+
RAW_DIR = Path("./data/raw")
|
| 19 |
+
FINAL_DIR = Path("./data/final")
|
| 20 |
+
FINAL_TRAIN = FINAL_DIR / "train.jsonl"
|
| 21 |
+
LOG_DIR = Path("./logs")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _safe_get(item: Dict[str, object], keys: Iterable[str]) -> str:
|
| 25 |
+
for key in keys:
|
| 26 |
+
value = item.get(key)
|
| 27 |
+
if value:
|
| 28 |
+
return str(value)
|
| 29 |
+
return ""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _write_jsonl(path: Path, rows: Iterable[Dict[str, str]]) -> int:
|
| 33 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 34 |
+
count = 0
|
| 35 |
+
with path.open("w", encoding="utf-8") as f:
|
| 36 |
+
for row in rows:
|
| 37 |
+
if not row.get("instruction") or not row.get("response"):
|
| 38 |
+
continue
|
| 39 |
+
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 40 |
+
count += 1
|
| 41 |
+
return count
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _iter_jsonl(path: Path) -> Iterable[Dict[str, object]]:
|
| 45 |
+
with path.open("r", encoding="utf-8") as f:
|
| 46 |
+
for line in f:
|
| 47 |
+
line = line.strip()
|
| 48 |
+
if not line:
|
| 49 |
+
continue
|
| 50 |
+
try:
|
| 51 |
+
yield json.loads(line)
|
| 52 |
+
except json.JSONDecodeError:
|
| 53 |
+
continue
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _source_to_category(source_name: str) -> str:
|
| 57 |
+
s = source_name.lower()
|
| 58 |
+
if any(k in s for k in ("codealpaca", "evol", "ultrachat", "openhermes", "orca")):
|
| 59 |
+
return "instruction"
|
| 60 |
+
if any(
|
| 61 |
+
k in s
|
| 62 |
+
for k in (
|
| 63 |
+
"leetcode",
|
| 64 |
+
"contest",
|
| 65 |
+
"problem",
|
| 66 |
+
"mbpp",
|
| 67 |
+
"humaneval",
|
| 68 |
+
"apps",
|
| 69 |
+
"codeforces",
|
| 70 |
+
"codesearchnet_problem",
|
| 71 |
+
)
|
| 72 |
+
):
|
| 73 |
+
return "problem"
|
| 74 |
+
return "structured"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _decode_text(value) -> str:
|
| 78 |
+
if value is None:
|
| 79 |
return ""
|
| 80 |
+
if isinstance(value, str):
|
| 81 |
+
return value
|
| 82 |
+
if isinstance(value, bytes):
|
| 83 |
+
return value.decode("utf-8", errors="ignore")
|
| 84 |
+
return str(value)
|
| 85 |
|
| 86 |
|
| 87 |
+
def _extract_solution_from_code_contests(item: Dict[str, object]) -> str:
|
| 88 |
+
sols = item.get("solutions")
|
| 89 |
+
if isinstance(sols, dict):
|
| 90 |
+
# Typical schema: {"language": [...], "solution": [bytes...]}
|
| 91 |
+
cand = sols.get("solution")
|
| 92 |
+
if isinstance(cand, list):
|
| 93 |
+
# Prefer Python-looking snippets when possible.
|
| 94 |
+
for s in cand:
|
| 95 |
+
t = _decode_text(s)
|
| 96 |
+
if re.search(r"\bdef\b|\bimport\b|\bprint\(", t):
|
| 97 |
+
return t
|
| 98 |
+
if cand:
|
| 99 |
+
return _decode_text(cand[0])
|
| 100 |
+
if isinstance(sols, list) and sols:
|
| 101 |
+
return _decode_text(sols[0])
|
| 102 |
+
return _safe_get(item, ["solution", "answer", "code"])
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
+
def _extract_many_code_contests_solutions(item: Dict[str, object], max_per_problem: int = 6) -> List[str]:
|
| 106 |
+
out: List[str] = []
|
| 107 |
+
sols = item.get("solutions")
|
| 108 |
+
if isinstance(sols, dict):
|
| 109 |
+
cand = sols.get("solution")
|
| 110 |
+
if isinstance(cand, list):
|
| 111 |
+
for s in cand:
|
| 112 |
+
t = _decode_text(s).strip()
|
| 113 |
+
if not t:
|
| 114 |
+
continue
|
| 115 |
+
if t not in out:
|
| 116 |
+
out.append(t)
|
| 117 |
+
if len(out) >= max_per_problem:
|
| 118 |
+
break
|
| 119 |
+
if not out:
|
| 120 |
+
one = _extract_solution_from_code_contests(item).strip()
|
| 121 |
+
if one:
|
| 122 |
+
out.append(one)
|
| 123 |
+
return out
|
| 124 |
|
| 125 |
|
| 126 |
+
def _extract_many_apps_solutions(item: Dict[str, object], max_per_problem: int = 5) -> List[str]:
|
| 127 |
+
out: List[str] = []
|
| 128 |
+
for key in ("solutions", "solution", "answer", "code"):
|
| 129 |
+
val = item.get(key)
|
| 130 |
+
if isinstance(val, list):
|
| 131 |
+
for x in val:
|
| 132 |
+
t = _decode_text(x).strip()
|
| 133 |
+
if t and t not in out:
|
| 134 |
+
out.append(t)
|
| 135 |
+
if len(out) >= max_per_problem:
|
| 136 |
+
return out
|
| 137 |
+
elif isinstance(val, dict):
|
| 138 |
+
for x in val.values():
|
| 139 |
+
if isinstance(x, list):
|
| 140 |
+
for y in x:
|
| 141 |
+
t = _decode_text(y).strip()
|
| 142 |
+
if t and t not in out:
|
| 143 |
+
out.append(t)
|
| 144 |
+
if len(out) >= max_per_problem:
|
| 145 |
+
return out
|
| 146 |
+
else:
|
| 147 |
+
t = _decode_text(val).strip()
|
| 148 |
+
if t and t not in out:
|
| 149 |
+
out.append(t)
|
| 150 |
+
if len(out) >= max_per_problem:
|
| 151 |
+
return out
|
| 152 |
+
return out
|
| 153 |
|
| 154 |
|
| 155 |
+
def _collect_code_candidates(value, out: List[str], max_per_problem: int) -> None:
|
| 156 |
+
if len(out) >= max_per_problem:
|
| 157 |
return
|
| 158 |
+
if value is None:
|
| 159 |
+
return
|
| 160 |
+
if isinstance(value, str):
|
| 161 |
+
v = value.strip()
|
| 162 |
+
if v and v not in out:
|
| 163 |
+
out.append(v)
|
| 164 |
+
return
|
| 165 |
+
if isinstance(value, bytes):
|
| 166 |
+
v = _decode_text(value).strip()
|
| 167 |
+
if v and v not in out:
|
| 168 |
+
out.append(v)
|
| 169 |
+
return
|
| 170 |
+
if isinstance(value, list):
|
| 171 |
+
for x in value:
|
| 172 |
+
_collect_code_candidates(x, out, max_per_problem)
|
| 173 |
+
if len(out) >= max_per_problem:
|
| 174 |
+
return
|
| 175 |
+
return
|
| 176 |
+
if isinstance(value, dict):
|
| 177 |
+
for k in ("solution", "solutions", "code", "answer", "python", "cpp", "java", "javascript"):
|
| 178 |
+
if k in value:
|
| 179 |
+
_collect_code_candidates(value.get(k), out, max_per_problem)
|
| 180 |
+
if len(out) >= max_per_problem:
|
| 181 |
+
return
|
| 182 |
+
for v in value.values():
|
| 183 |
+
_collect_code_candidates(v, out, max_per_problem)
|
| 184 |
+
if len(out) >= max_per_problem:
|
| 185 |
+
return
|
| 186 |
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
+
def _extract_many_generic_solutions(
|
| 189 |
+
item: Dict[str, object],
|
| 190 |
+
*,
|
| 191 |
+
max_per_problem: int = 6,
|
| 192 |
+
) -> List[str]:
|
| 193 |
+
out: List[str] = []
|
| 194 |
+
for key in ("solutions", "solution", "code", "answer", "python", "cpp", "java", "javascript"):
|
| 195 |
+
_collect_code_candidates(item.get(key), out, max_per_problem)
|
| 196 |
+
if len(out) >= max_per_problem:
|
| 197 |
+
break
|
| 198 |
+
return out
|
| 199 |
|
| 200 |
|
| 201 |
+
def _compute_targets(target_size: int, min_problem_samples: int) -> Dict[str, int]:
|
| 202 |
+
instruction_target = int(target_size * 0.60)
|
| 203 |
+
structured_target = int(target_size * 0.30)
|
| 204 |
+
problem_target = target_size - instruction_target - structured_target
|
| 205 |
+
problem_target = max(problem_target, min_problem_samples)
|
| 206 |
+
remainder = target_size - problem_target
|
| 207 |
+
if remainder < 0:
|
| 208 |
+
raise RuntimeError(
|
| 209 |
+
f"Invalid target sizing: min_problem_samples={min_problem_samples} exceeds "
|
| 210 |
+
f"target_size={target_size}."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
)
|
| 212 |
+
instruction_target = int(remainder * (60.0 / 90.0))
|
| 213 |
+
structured_target = remainder - instruction_target
|
| 214 |
+
return {
|
| 215 |
+
"instruction": instruction_target,
|
| 216 |
+
"structured": structured_target,
|
| 217 |
+
"problem": problem_target,
|
| 218 |
+
}
|
| 219 |
|
| 220 |
|
| 221 |
+
def rebalance_final_dataset(
|
| 222 |
+
*,
|
| 223 |
+
raw_paths: List[Path],
|
| 224 |
+
output_path: Path,
|
| 225 |
+
target_size: int,
|
| 226 |
+
min_tokens: int,
|
| 227 |
+
max_tokens: int,
|
| 228 |
+
min_problem_samples: int,
|
| 229 |
+
logger,
|
| 230 |
+
) -> Dict[str, object]:
|
| 231 |
+
# Post-build rebalance using streaming + temp shards, then exact down/upsample.
|
| 232 |
+
tmp_dir = output_path.parent / "_rebalance_tmp"
|
| 233 |
+
ensure_dirs([tmp_dir])
|
| 234 |
+
|
| 235 |
+
shard_paths = {
|
| 236 |
+
"instruction": tmp_dir / "instruction.jsonl",
|
| 237 |
+
"structured": tmp_dir / "structured.jsonl",
|
| 238 |
+
"problem": tmp_dir / "problem.jsonl",
|
| 239 |
+
}
|
| 240 |
+
for p in shard_paths.values():
|
| 241 |
+
if p.exists():
|
| 242 |
+
p.unlink()
|
| 243 |
+
|
| 244 |
+
dedupe_db = tmp_dir / "rebalance_seen.sqlite"
|
| 245 |
+
if dedupe_db.exists():
|
| 246 |
+
dedupe_db.unlink()
|
| 247 |
+
for suffix in ("-wal", "-shm"):
|
| 248 |
+
side = dedupe_db.with_name(dedupe_db.name + suffix)
|
| 249 |
+
if side.exists():
|
| 250 |
+
side.unlink()
|
| 251 |
+
|
| 252 |
+
conn = sqlite3.connect(str(dedupe_db))
|
| 253 |
+
conn.execute("PRAGMA journal_mode=WAL;")
|
| 254 |
+
conn.execute("CREATE TABLE IF NOT EXISTS seen_hashes (h TEXT PRIMARY KEY)")
|
| 255 |
+
|
| 256 |
+
def is_dup(instruction: str, response: str) -> bool:
|
| 257 |
+
import hashlib
|
| 258 |
+
|
| 259 |
+
h = hashlib.sha256(f"{instruction}||{response}".encode("utf-8")).hexdigest()
|
| 260 |
+
try:
|
| 261 |
+
conn.execute("INSERT INTO seen_hashes(h) VALUES (?)", (h,))
|
| 262 |
+
return False
|
| 263 |
+
except sqlite3.IntegrityError:
|
| 264 |
+
return True
|
| 265 |
+
|
| 266 |
+
shard_counts = Counter()
|
| 267 |
+
with (
|
| 268 |
+
shard_paths["instruction"].open("w", encoding="utf-8") as f_inst,
|
| 269 |
+
shard_paths["structured"].open("w", encoding="utf-8") as f_struct,
|
| 270 |
+
shard_paths["problem"].open("w", encoding="utf-8") as f_prob,
|
| 271 |
+
):
|
| 272 |
+
writers = {
|
| 273 |
+
"instruction": f_inst,
|
| 274 |
+
"structured": f_struct,
|
| 275 |
+
"problem": f_prob,
|
| 276 |
+
}
|
| 277 |
+
for raw_path in raw_paths:
|
| 278 |
+
if not raw_path.exists():
|
| 279 |
continue
|
| 280 |
+
src_default = raw_path.stem
|
| 281 |
+
for rec in tqdm(_iter_jsonl(raw_path), desc=f"rebalance_scan:{raw_path.name}", unit="rows"):
|
| 282 |
+
if "_source" not in rec:
|
| 283 |
+
rec["_source"] = src_default
|
| 284 |
+
if "_category" not in rec:
|
| 285 |
+
rec["_category"] = _source_to_category(src_default)
|
| 286 |
+
cleaned = clean_record(rec, min_tokens=min_tokens, max_tokens=max_tokens)
|
| 287 |
+
if cleaned is None:
|
| 288 |
+
continue
|
| 289 |
+
if is_dup(cleaned["instruction"], cleaned["response"]):
|
| 290 |
+
continue
|
| 291 |
+
cat = cleaned["_category"]
|
| 292 |
+
if cat not in writers:
|
| 293 |
+
cat = _source_to_category(cleaned.get("_source", ""))
|
| 294 |
+
line_obj = {
|
| 295 |
+
"instruction": cleaned["instruction"],
|
| 296 |
+
"response": cleaned["response"],
|
| 297 |
+
"_source": cleaned["_source"],
|
| 298 |
+
"_category": cat,
|
| 299 |
+
}
|
| 300 |
+
writers[cat].write(json.dumps(line_obj, ensure_ascii=False) + "\n")
|
| 301 |
+
shard_counts[cat] += 1
|
| 302 |
+
conn.commit()
|
| 303 |
+
conn.close()
|
| 304 |
+
|
| 305 |
+
targets = _compute_targets(target_size=target_size, min_problem_samples=min_problem_samples)
|
| 306 |
+
logger.info("Rebalance targets: %s (available=%s)", targets, dict(shard_counts))
|
| 307 |
+
|
| 308 |
+
source_breakdown = Counter()
|
| 309 |
+
category_breakdown = Counter()
|
| 310 |
+
total_tokens = 0
|
| 311 |
+
total_samples = 0
|
| 312 |
+
problem_real_count = 0
|
| 313 |
+
problem_synthetic_count = 0
|
| 314 |
+
max_synth_problem = int(targets["problem"] * 0.30)
|
| 315 |
+
|
| 316 |
+
def write_from_shard(cat: str, needed: int, out_f) -> int:
|
| 317 |
+
nonlocal total_samples, total_tokens, problem_real_count, problem_synthetic_count
|
| 318 |
+
written = 0
|
| 319 |
+
shard = shard_paths[cat]
|
| 320 |
+
if not shard.exists():
|
| 321 |
+
return 0
|
| 322 |
+
with shard.open("r", encoding="utf-8") as f:
|
| 323 |
+
for line in f:
|
| 324 |
+
if written >= needed:
|
| 325 |
+
break
|
| 326 |
+
obj = json.loads(line)
|
| 327 |
+
src = obj.get("_source", "unknown")
|
| 328 |
+
is_problem_synth = cat == "problem" and "codesearchnet_problem_fallback" in src
|
| 329 |
+
if is_problem_synth and problem_synthetic_count >= max_synth_problem:
|
| 330 |
+
continue
|
| 331 |
+
out_f.write(
|
| 332 |
+
json.dumps(
|
| 333 |
+
{"instruction": obj["instruction"], "response": obj["response"]},
|
| 334 |
+
ensure_ascii=False,
|
| 335 |
+
)
|
| 336 |
+
+ "\n"
|
| 337 |
)
|
| 338 |
+
written += 1
|
| 339 |
+
total_samples += 1
|
| 340 |
+
category_breakdown[cat] += 1
|
| 341 |
+
source_breakdown[src] += 1
|
| 342 |
+
if cat == "problem":
|
| 343 |
+
if is_problem_synth:
|
| 344 |
+
problem_synthetic_count += 1
|
| 345 |
+
else:
|
| 346 |
+
problem_real_count += 1
|
| 347 |
+
total_tokens += len((obj["instruction"] + " " + obj["response"]).split())
|
| 348 |
+
return written
|
| 349 |
+
|
| 350 |
+
def upsample_shard(cat: str, needed: int, out_f) -> int:
|
| 351 |
+
nonlocal total_samples, total_tokens, problem_real_count, problem_synthetic_count
|
| 352 |
+
shard = shard_paths[cat]
|
| 353 |
+
if not shard.exists() or needed <= 0:
|
| 354 |
+
return 0
|
| 355 |
+
written = 0
|
| 356 |
+
while written < needed:
|
| 357 |
+
made_progress = 0
|
| 358 |
+
with shard.open("r", encoding="utf-8") as f:
|
| 359 |
+
for line in f:
|
| 360 |
+
if written >= needed:
|
| 361 |
+
break
|
| 362 |
+
obj = json.loads(line)
|
| 363 |
+
src = obj.get("_source", "unknown")
|
| 364 |
+
is_problem_synth = cat == "problem" and "codesearchnet_problem_fallback" in src
|
| 365 |
+
if is_problem_synth and problem_synthetic_count >= max_synth_problem:
|
| 366 |
+
continue
|
| 367 |
+
out_f.write(
|
| 368 |
+
json.dumps(
|
| 369 |
+
{"instruction": obj["instruction"], "response": obj["response"]},
|
| 370 |
+
ensure_ascii=False,
|
| 371 |
+
)
|
| 372 |
+
+ "\n"
|
| 373 |
+
)
|
| 374 |
+
written += 1
|
| 375 |
+
made_progress += 1
|
| 376 |
+
total_samples += 1
|
| 377 |
+
category_breakdown[cat] += 1
|
| 378 |
+
source_breakdown[src] += 1
|
| 379 |
+
if cat == "problem":
|
| 380 |
+
if is_problem_synth:
|
| 381 |
+
problem_synthetic_count += 1
|
| 382 |
+
else:
|
| 383 |
+
problem_real_count += 1
|
| 384 |
+
total_tokens += len((obj["instruction"] + " " + obj["response"]).split())
|
| 385 |
+
if made_progress == 0:
|
| 386 |
+
break
|
| 387 |
+
return written
|
| 388 |
+
|
| 389 |
+
with output_path.open("w", encoding="utf-8") as out_f:
|
| 390 |
+
for cat in ("instruction", "structured", "problem"):
|
| 391 |
+
want = targets[cat]
|
| 392 |
+
got = write_from_shard(cat, want, out_f)
|
| 393 |
+
if got < want:
|
| 394 |
+
deficit = want - got
|
| 395 |
+
if cat == "problem":
|
| 396 |
+
logger.warning(
|
| 397 |
+
"Category %s shortfall: need=%d got=%d (no upsampling allowed for problem).",
|
| 398 |
+
cat,
|
| 399 |
+
want,
|
| 400 |
+
got,
|
| 401 |
+
)
|
| 402 |
+
else:
|
| 403 |
+
upsampled = upsample_shard(cat, deficit, out_f)
|
| 404 |
+
logger.warning(
|
| 405 |
+
"Category %s shortfall: need=%d got=%d upsampled=%d",
|
| 406 |
+
cat,
|
| 407 |
+
want,
|
| 408 |
+
got,
|
| 409 |
+
upsampled,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
inst = category_breakdown["instruction"]
|
| 413 |
+
struct = category_breakdown["structured"]
|
| 414 |
+
problem = category_breakdown["problem"]
|
| 415 |
+
instruction_vs_raw = {
|
| 416 |
+
"instruction_pct": round(100.0 * inst / max(total_samples, 1), 2),
|
| 417 |
+
"raw_converted_pct": round(100.0 * (struct + problem) / max(total_samples, 1), 2),
|
| 418 |
+
}
|
| 419 |
+
avg_len = round(total_tokens / max(total_samples, 1), 2)
|
| 420 |
+
|
| 421 |
+
return {
|
| 422 |
+
"total_samples": total_samples,
|
| 423 |
+
"avg_length_tokens": avg_len,
|
| 424 |
+
"source_breakdown": dict(source_breakdown),
|
| 425 |
+
"category_breakdown": dict(category_breakdown),
|
| 426 |
+
"instruction_vs_raw_ratio": instruction_vs_raw,
|
| 427 |
+
"targets": targets,
|
| 428 |
+
"problem_real_count": problem_real_count,
|
| 429 |
+
"problem_synthetic_count": problem_synthetic_count,
|
| 430 |
+
"problem_synthetic_pct": round(
|
| 431 |
+
100.0 * problem_synthetic_count / max(problem_real_count + problem_synthetic_count, 1), 2
|
| 432 |
+
),
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _try_load_dataset(candidates: List[Dict[str, object]], logger):
|
| 437 |
+
last_exc: Optional[Exception] = None
|
| 438 |
+
for cand in candidates:
|
| 439 |
+
try:
|
| 440 |
+
ds = load_dataset(**cand)
|
| 441 |
+
logger.info("Loaded dataset: %s", cand)
|
| 442 |
+
return ds
|
| 443 |
+
except Exception as exc:
|
| 444 |
+
logger.warning("Dataset load failed for %s: %s", cand, exc)
|
| 445 |
+
last_exc = exc
|
| 446 |
+
if last_exc:
|
| 447 |
+
raise last_exc
|
| 448 |
+
raise RuntimeError("No dataset candidates provided.")
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def fetch_instruction_codealpaca(raw_path: Path, limit: int, logger) -> int:
|
| 452 |
+
ds = _try_load_dataset(
|
| 453 |
+
[
|
| 454 |
+
{"path": "sahil2801/CodeAlpaca-20k", "split": "train"},
|
| 455 |
+
{"path": "HuggingFaceH4/CodeAlpaca_20K", "split": "train"},
|
| 456 |
+
],
|
| 457 |
+
logger,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
def rows():
|
| 461 |
+
emitted = 0
|
| 462 |
+
for item in tqdm(ds, desc="codealpaca", unit="rows"):
|
| 463 |
+
if emitted >= limit:
|
| 464 |
+
break
|
| 465 |
+
instruction = _safe_get(item, ["instruction"])
|
| 466 |
+
inp = _safe_get(item, ["input"])
|
| 467 |
+
output = _safe_get(item, ["output", "response", "answer"])
|
| 468 |
+
if inp:
|
| 469 |
+
instruction = f"{instruction}\n\nInput:\n{inp}".strip()
|
| 470 |
+
emitted += 1
|
| 471 |
+
yield build_instruction_sample(
|
| 472 |
+
instruction=instruction,
|
| 473 |
+
response=output,
|
| 474 |
+
source="codealpaca",
|
| 475 |
+
category="instruction",
|
| 476 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
|
| 478 |
+
return _write_jsonl(raw_path, rows())
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def fetch_instruction_evol(raw_path: Path, limit: int, logger) -> int:
|
| 482 |
+
ds = _try_load_dataset(
|
| 483 |
+
[
|
| 484 |
+
{"path": "nickrosh/Evol-Instruct-Code-80k-v1", "split": "train"},
|
| 485 |
+
{"path": "WizardLMTeam/WizardCoder-Evol-Instruct-V2-196k", "split": "train"},
|
| 486 |
+
{"path": "ise-uiuc/Magicoder-OSS-Instruct-75K", "split": "train"},
|
| 487 |
+
],
|
| 488 |
+
logger,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
def rows():
|
| 492 |
+
emitted = 0
|
| 493 |
+
for item in tqdm(ds, desc="evol_instruct_code", unit="rows"):
|
| 494 |
+
if emitted >= limit:
|
| 495 |
+
break
|
| 496 |
+
instruction = _safe_get(item, ["instruction", "prompt", "question"])
|
| 497 |
+
inp = _safe_get(item, ["input"])
|
| 498 |
+
output = _safe_get(item, ["output", "response", "answer"])
|
| 499 |
+
if inp:
|
| 500 |
+
instruction = f"{instruction}\n\nInput:\n{inp}".strip()
|
| 501 |
+
emitted += 1
|
| 502 |
+
yield build_instruction_sample(
|
| 503 |
+
instruction=instruction,
|
| 504 |
+
response=output,
|
| 505 |
+
source="evol_instruct_code",
|
| 506 |
+
category="instruction",
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
return _write_jsonl(raw_path, rows())
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def fetch_instruction_ultrachat_code(raw_path: Path, limit: int, logger) -> int:
|
| 513 |
+
ds = _try_load_dataset(
|
| 514 |
+
[
|
| 515 |
+
{"path": "HuggingFaceH4/ultrachat_200k", "split": "train_sft"},
|
| 516 |
+
{"path": "stingning/ultrachat", "split": "train"},
|
| 517 |
+
],
|
| 518 |
+
logger,
|
| 519 |
+
)
|
| 520 |
+
code_terms = ("python", "javascript", "typescript", "java", "code", "api", "backend", "frontend")
|
| 521 |
+
|
| 522 |
+
def rows():
|
| 523 |
+
emitted = 0
|
| 524 |
+
for item in tqdm(ds, desc="ultrachat_code", unit="rows"):
|
| 525 |
+
if emitted >= limit:
|
| 526 |
+
break
|
| 527 |
+
msgs = item.get("messages") or item.get("conversation") or item.get("conversations")
|
| 528 |
+
if not isinstance(msgs, list) or len(msgs) < 2:
|
| 529 |
+
continue
|
| 530 |
+
user = ""
|
| 531 |
+
assistant = ""
|
| 532 |
+
for msg in msgs:
|
| 533 |
+
if not isinstance(msg, dict):
|
| 534 |
+
continue
|
| 535 |
+
role = str(msg.get("role", "")).lower()
|
| 536 |
+
content = str(msg.get("content", "")).strip()
|
| 537 |
+
if role in {"user", "human"} and not user:
|
| 538 |
+
user = content
|
| 539 |
+
if role in {"assistant", "gpt"} and user and not assistant:
|
| 540 |
+
assistant = content
|
| 541 |
+
break
|
| 542 |
+
if not user or not assistant:
|
| 543 |
+
continue
|
| 544 |
+
low = (user + " " + assistant).lower()
|
| 545 |
+
if not any(term in low for term in code_terms):
|
| 546 |
continue
|
| 547 |
+
emitted += 1
|
| 548 |
+
yield build_instruction_sample(
|
| 549 |
+
instruction=user,
|
| 550 |
+
response=assistant,
|
| 551 |
+
source="ultrachat_code",
|
| 552 |
+
category="instruction",
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
return _write_jsonl(raw_path, rows())
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def fetch_instruction_openhermes_code(raw_path: Path, limit: int, logger) -> int:
|
| 559 |
+
ds = _try_load_dataset(
|
| 560 |
+
[
|
| 561 |
+
{"path": "teknium/OpenHermes-2.5", "split": "train"},
|
| 562 |
+
{"path": "Open-Orca/OpenOrca", "split": "train"},
|
| 563 |
+
],
|
| 564 |
+
logger,
|
| 565 |
+
)
|
| 566 |
+
code_terms = ("python", "javascript", "typescript", "java", "code", "function", "api", "fastapi")
|
| 567 |
|
| 568 |
+
def rows():
|
| 569 |
+
emitted = 0
|
| 570 |
+
for item in tqdm(ds, desc="openhermes_code", unit="rows"):
|
| 571 |
+
if emitted >= limit:
|
| 572 |
+
break
|
| 573 |
+
instruction = _safe_get(item, ["instruction", "question", "prompt"])
|
| 574 |
+
response = _safe_get(item, ["output", "response", "answer"])
|
| 575 |
+
if (not instruction or not response) and isinstance(item.get("conversations"), list):
|
| 576 |
+
user = ""
|
| 577 |
+
assistant = ""
|
| 578 |
+
for msg in item.get("conversations"):
|
| 579 |
+
if not isinstance(msg, dict):
|
| 580 |
+
continue
|
| 581 |
+
from_role = str(msg.get("from", "")).lower()
|
| 582 |
+
value = str(msg.get("value", "")).strip()
|
| 583 |
+
if from_role in {"human", "user"} and not user:
|
| 584 |
+
user = value
|
| 585 |
+
if from_role in {"gpt", "assistant"} and user and not assistant:
|
| 586 |
+
assistant = value
|
| 587 |
+
break
|
| 588 |
+
instruction = instruction or user
|
| 589 |
+
response = response or assistant
|
| 590 |
+
if not instruction or not response:
|
| 591 |
continue
|
| 592 |
+
low = (instruction + " " + response).lower()
|
| 593 |
+
if not any(term in low for term in code_terms):
|
| 594 |
continue
|
| 595 |
+
emitted += 1
|
| 596 |
+
yield build_instruction_sample(
|
| 597 |
+
instruction=instruction,
|
| 598 |
+
response=response,
|
| 599 |
+
source="openhermes_code",
|
| 600 |
+
category="instruction",
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
return _write_jsonl(raw_path, rows())
|
| 604 |
+
|
| 605 |
|
| 606 |
+
def fetch_structured_codesearchnet(raw_path: Path, limit: int, logger) -> int:
|
| 607 |
+
languages = ["python", "javascript", "java"]
|
| 608 |
+
per_lang = max(1, limit // max(1, len(languages)))
|
| 609 |
+
|
| 610 |
+
def rows():
|
| 611 |
+
emitted = 0
|
| 612 |
+
for lang in languages:
|
| 613 |
+
if emitted >= limit:
|
| 614 |
+
break
|
| 615 |
+
ds = None
|
| 616 |
+
cache_by_lang = Path(f"./data/cache/raw/code_search_net_{lang}")
|
| 617 |
+
if cache_by_lang.exists():
|
| 618 |
+
try:
|
| 619 |
+
ds = load_from_disk(str(cache_by_lang))["train"]
|
| 620 |
+
logger.info("Loaded cached CodeSearchNet language=%s from %s", lang, cache_by_lang)
|
| 621 |
+
except Exception as exc:
|
| 622 |
+
logger.warning("Failed cached CodeSearchNet for %s: %s", lang, exc)
|
| 623 |
+
if ds is None:
|
| 624 |
+
try:
|
| 625 |
+
ds = load_dataset("code_search_net", lang, split="train", streaming=True)
|
| 626 |
+
logger.info("Loaded streamed CodeSearchNet language=%s", lang)
|
| 627 |
+
except Exception as exc:
|
| 628 |
+
logger.warning("Skipping CodeSearchNet language=%s: %s", lang, exc)
|
| 629 |
+
continue
|
| 630 |
+
|
| 631 |
+
lang_count = 0
|
| 632 |
+
for item in tqdm(ds, desc=f"codesearchnet_{lang}", unit="rows"):
|
| 633 |
+
if emitted >= limit or lang_count >= per_lang:
|
| 634 |
+
break
|
| 635 |
+
code = _safe_get(item, ["whole_func_string", "code"])
|
| 636 |
+
path = _safe_get(item, ["path", "func_name"])
|
| 637 |
+
doc = _safe_get(item, ["docstring", "func_documentation_string"])
|
| 638 |
+
if not code:
|
| 639 |
+
continue
|
| 640 |
+
emitted += 1
|
| 641 |
+
lang_count += 1
|
| 642 |
+
yield build_instruction_sample(
|
| 643 |
+
code=code,
|
| 644 |
+
instruction=doc,
|
| 645 |
+
language=lang,
|
| 646 |
+
path=path,
|
| 647 |
+
source=f"codesearchnet_{lang}",
|
| 648 |
+
category="structured",
|
| 649 |
)
|
| 650 |
+
|
| 651 |
+
return _write_jsonl(raw_path, rows())
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def fetch_structured_github_functions(raw_path: Path, limit: int, logger) -> int:
|
| 655 |
+
ds = None
|
| 656 |
+
cache_path = Path("./data/cache/raw/code_search_net_python")
|
| 657 |
+
if cache_path.exists():
|
| 658 |
+
ds = load_from_disk(str(cache_path))["train"]
|
| 659 |
+
logger.info("Using cached GitHub function corpus from %s", cache_path.resolve())
|
| 660 |
+
else:
|
| 661 |
+
ds = load_dataset("code_search_net", "python", split="train", streaming=True)
|
| 662 |
+
logger.info("Using streamed CodeSearchNet python as GitHub-curated function source.")
|
| 663 |
+
|
| 664 |
+
def rows():
|
| 665 |
+
emitted = 0
|
| 666 |
+
for item in tqdm(ds, desc="github_curated_functions", unit="rows"):
|
| 667 |
+
if emitted >= limit:
|
| 668 |
+
break
|
| 669 |
+
code = _safe_get(item, ["whole_func_string", "code", "content"])
|
| 670 |
+
path = _safe_get(item, ["path", "func_name"])
|
| 671 |
+
repo = _safe_get(item, ["repo", "repository_name"])
|
| 672 |
+
doc = _safe_get(item, ["docstring", "func_documentation_string"])
|
| 673 |
+
if not code:
|
| 674 |
+
continue
|
| 675 |
+
title = f"{repo}/{path}" if repo and path else path
|
| 676 |
+
emitted += 1
|
| 677 |
+
yield build_instruction_sample(
|
| 678 |
+
code=code,
|
| 679 |
+
instruction=doc,
|
| 680 |
+
language="python",
|
| 681 |
+
path=path,
|
| 682 |
+
title=title,
|
| 683 |
+
source="github_curated_functions",
|
| 684 |
+
category="structured",
|
| 685 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
|
| 687 |
+
return _write_jsonl(raw_path, rows())
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def fetch_problem_leetcode(raw_path: Path, limit: int, logger) -> int:
|
| 691 |
+
def rows():
|
| 692 |
+
emitted = 0
|
| 693 |
+
synth_emitted = 0
|
| 694 |
+
candidates = [
|
| 695 |
+
("greengerong/leetcode", {"path": "greengerong/leetcode", "split": "train"}),
|
| 696 |
+
("deepmind/code_contests", {"path": "deepmind/code_contests", "split": "train"}),
|
| 697 |
+
("codeparrot/apps", {"path": "codeparrot/apps", "split": "train"}),
|
| 698 |
+
("google-research-datasets/mbpp", {"path": "google-research-datasets/mbpp", "split": "train"}),
|
| 699 |
+
("openai_humaneval", {"path": "openai_humaneval", "split": "test"}),
|
| 700 |
+
# Streamed high-volume real problem source; avoid full git clone.
|
| 701 |
+
("open-r1/codeforces", {"path": "open-r1/codeforces", "split": "train", "streaming": True}),
|
| 702 |
+
]
|
| 703 |
+
|
| 704 |
+
# Optional local codeforces/problem-solution JSONL fallback.
|
| 705 |
+
local_problem_files = sorted(RAW_DIR.glob("codeforces*.jsonl")) + sorted(
|
| 706 |
+
RAW_DIR.glob("problem_solution*.jsonl")
|
| 707 |
+
)
|
| 708 |
+
if not local_problem_files:
|
| 709 |
+
logger.warning(
|
| 710 |
+
"Codeforces dataset missing – recommended for production quality."
|
| 711 |
+
)
|
| 712 |
+
for local_file in local_problem_files:
|
| 713 |
+
if emitted >= limit:
|
| 714 |
+
break
|
| 715 |
+
for item in tqdm(_iter_jsonl(local_file), desc=f"problem_local:{local_file.name}", unit="rows"):
|
| 716 |
+
if emitted >= limit:
|
| 717 |
+
break
|
| 718 |
+
problem = _safe_get(item, ["problem", "instruction", "statement", "question"])
|
| 719 |
+
solution = _safe_get(item, ["solution", "response", "answer", "code"])
|
| 720 |
+
if not problem or not solution:
|
| 721 |
+
continue
|
| 722 |
+
emitted += 1
|
| 723 |
+
yield build_instruction_sample(
|
| 724 |
+
instruction=f"Solve the following problem:\n\n{problem}",
|
| 725 |
+
response=solution,
|
| 726 |
+
source="codeforces_local",
|
| 727 |
+
category="problem",
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
for source_name, cand in candidates:
|
| 731 |
+
if emitted >= limit:
|
| 732 |
+
break
|
| 733 |
+
try:
|
| 734 |
+
ds = load_dataset(**cand)
|
| 735 |
+
logger.info("Loaded problem dataset: %s", cand)
|
| 736 |
+
except Exception as exc:
|
| 737 |
+
logger.warning("Problem dataset load failed for %s: %s", cand, exc)
|
| 738 |
+
if source_name == "codeparrot/apps":
|
| 739 |
+
apps_local = sorted(RAW_DIR.glob("apps*.jsonl")) + sorted(RAW_DIR.glob("apps*.json"))
|
| 740 |
+
if not apps_local:
|
| 741 |
+
logger.warning(
|
| 742 |
+
"APPS dataset unavailable via HF and local APPS JSON missing in ./data/raw."
|
| 743 |
+
)
|
| 744 |
+
for local_file in apps_local:
|
| 745 |
+
if emitted >= limit:
|
| 746 |
+
break
|
| 747 |
+
for item in tqdm(
|
| 748 |
+
_iter_jsonl(local_file),
|
| 749 |
+
desc=f"problem_apps_local:{local_file.name}",
|
| 750 |
+
unit="rows",
|
| 751 |
+
):
|
| 752 |
+
if emitted >= limit:
|
| 753 |
+
break
|
| 754 |
+
problem = _safe_get(item, ["question", "prompt", "problem", "statement"])
|
| 755 |
+
solution = _safe_get(item, ["solution", "answer", "code"])
|
| 756 |
+
if not problem or not solution:
|
| 757 |
+
continue
|
| 758 |
+
emitted += 1
|
| 759 |
+
yield build_instruction_sample(
|
| 760 |
+
instruction=f"Solve the following problem:\n\n{problem}",
|
| 761 |
+
response=solution,
|
| 762 |
+
source="problem_apps_local",
|
| 763 |
+
category="problem",
|
| 764 |
+
)
|
| 765 |
+
continue
|
| 766 |
+
for item in tqdm(ds, desc=f"problem_{source_name}", unit="rows"):
|
| 767 |
+
if emitted >= limit:
|
| 768 |
+
break
|
| 769 |
+
title = _safe_get(item, ["title", "name", "problem_id", "task_id"])
|
| 770 |
+
base_instruction = ""
|
| 771 |
+
solutions: List[str] = []
|
| 772 |
+
if source_name.endswith("mbpp"):
|
| 773 |
+
problem = _safe_get(item, ["text"])
|
| 774 |
+
tests = item.get("test_list") or []
|
| 775 |
+
test_blob = "\n".join(tests) if isinstance(tests, list) else _decode_text(tests)
|
| 776 |
+
if test_blob:
|
| 777 |
+
problem = f"{problem}\n\nTests:\n{test_blob}"
|
| 778 |
+
sol = _safe_get(item, ["code"])
|
| 779 |
+
solutions = [sol] if sol else []
|
| 780 |
+
base_instruction = f"Solve this coding problem: {title}\n\n{problem}"
|
| 781 |
+
elif source_name.endswith("humaneval"):
|
| 782 |
+
problem = _safe_get(item, ["prompt"])
|
| 783 |
+
tests = _safe_get(item, ["test"])
|
| 784 |
+
if tests:
|
| 785 |
+
problem = f"{problem}\n\nTests:\n{tests}"
|
| 786 |
+
sol = _safe_get(item, ["canonical_solution"])
|
| 787 |
+
solutions = [sol] if sol else []
|
| 788 |
+
base_instruction = f"Solve this coding problem: {title}\n\n{problem}"
|
| 789 |
+
elif source_name.endswith("code_contests"):
|
| 790 |
+
problem = _safe_get(item, ["description", "problem", "question", "prompt"])
|
| 791 |
+
solutions = _extract_many_code_contests_solutions(item, max_per_problem=6)
|
| 792 |
+
base_instruction = f"Solve this coding problem: {title}\n\n{problem}"
|
| 793 |
+
elif source_name.endswith("apps"):
|
| 794 |
+
problem = _safe_get(item, ["question", "problem", "prompt", "statement"])
|
| 795 |
+
solutions = _extract_many_apps_solutions(item, max_per_problem=5)
|
| 796 |
+
base_instruction = f"Solve this coding problem: {title}\n\n{problem}"
|
| 797 |
+
elif source_name.endswith("open-r1/codeforces"):
|
| 798 |
+
problem = _safe_get(
|
| 799 |
+
item,
|
| 800 |
+
["problem", "statement", "question", "prompt", "description", "content"],
|
| 801 |
+
)
|
| 802 |
+
solutions = _extract_many_generic_solutions(item, max_per_problem=6)
|
| 803 |
+
base_instruction = f"Solve this coding problem: {title}\n\n{problem}"
|
| 804 |
+
else:
|
| 805 |
+
problem = _safe_get(item, ["content", "description", "question", "prompt", "statement"])
|
| 806 |
+
langs = [
|
| 807 |
+
_safe_get(item, ["python"]),
|
| 808 |
+
_safe_get(item, ["javascript"]),
|
| 809 |
+
_safe_get(item, ["java"]),
|
| 810 |
+
_safe_get(item, ["c++"]),
|
| 811 |
+
_safe_get(item, ["answer"]),
|
| 812 |
+
_safe_get(item, ["code"]),
|
| 813 |
+
]
|
| 814 |
+
solutions = [s for s in langs if s]
|
| 815 |
+
if isinstance(item.get("solutions"), list):
|
| 816 |
+
for extra in item["solutions"]:
|
| 817 |
+
t = _decode_text(extra).strip()
|
| 818 |
+
if t and t not in solutions:
|
| 819 |
+
solutions.append(t)
|
| 820 |
+
base_instruction = f"Solve this coding problem: {title}\n\n{problem}"
|
| 821 |
+
if not problem or not solutions:
|
| 822 |
+
continue
|
| 823 |
+
for sol in solutions:
|
| 824 |
+
if emitted >= limit:
|
| 825 |
+
break
|
| 826 |
+
if not sol or len(sol.strip()) < 20:
|
| 827 |
+
continue
|
| 828 |
+
emitted += 1
|
| 829 |
+
yield build_instruction_sample(
|
| 830 |
+
instruction=base_instruction,
|
| 831 |
+
response=sol,
|
| 832 |
+
source=f"problem_{source_name.replace('/', '_')}",
|
| 833 |
+
category="problem",
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
# Final problem fallback from CodeSearchNet docstrings to boost high-quality problem pairs.
|
| 837 |
+
if emitted < limit:
|
| 838 |
+
synth_cap = int(limit * 0.30)
|
| 839 |
+
cache_path = Path("./data/cache/raw/code_search_net_python")
|
| 840 |
+
ds = None
|
| 841 |
+
if cache_path.exists():
|
| 842 |
+
try:
|
| 843 |
+
ds = load_from_disk(str(cache_path))["train"]
|
| 844 |
+
logger.info("Using cached CodeSearchNet Python for problem fallback.")
|
| 845 |
+
except Exception:
|
| 846 |
+
ds = None
|
| 847 |
+
if ds is None:
|
| 848 |
+
try:
|
| 849 |
+
ds = load_dataset("code_search_net", "python", split="train", streaming=True)
|
| 850 |
+
logger.info("Using streamed CodeSearchNet Python for problem fallback.")
|
| 851 |
+
except Exception as exc:
|
| 852 |
+
logger.warning("Problem fallback CodeSearchNet failed: %s", exc)
|
| 853 |
+
ds = None
|
| 854 |
+
if ds is not None:
|
| 855 |
+
for item in tqdm(ds, desc="problem_codesearchnet_fallback", unit="rows"):
|
| 856 |
+
if emitted >= limit or synth_emitted >= synth_cap:
|
| 857 |
+
break
|
| 858 |
+
doc = _safe_get(item, ["docstring", "func_documentation_string"])
|
| 859 |
+
code = _safe_get(item, ["whole_func_string", "code"])
|
| 860 |
+
if len(doc.strip()) < 30 or not code:
|
| 861 |
+
continue
|
| 862 |
+
emitted += 1
|
| 863 |
+
synth_emitted += 1
|
| 864 |
+
yield build_instruction_sample(
|
| 865 |
+
instruction=f"Solve the following programming task:\n\n{doc}",
|
| 866 |
+
response=code,
|
| 867 |
+
source="codesearchnet_problem_fallback",
|
| 868 |
+
category="problem",
|
| 869 |
+
)
|
| 870 |
|
| 871 |
+
return _write_jsonl(raw_path, rows())
|
|
|
|
|
|
|
| 872 |
|
|
|
|
| 873 |
|
| 874 |
+
def fetch_problem_codeforces(raw_path: Path, limit: int, logger) -> int:
|
| 875 |
+
source_file = RAW_DIR / "codeforces.jsonl"
|
| 876 |
+
if not source_file.exists():
|
| 877 |
+
logger.warning("Codeforces dataset file not found: %s", source_file.resolve())
|
| 878 |
+
return 0
|
| 879 |
|
| 880 |
+
def rows():
|
| 881 |
+
emitted = 0
|
| 882 |
+
for item in tqdm(_iter_jsonl(source_file), desc="problem_codeforces", unit="rows"):
|
| 883 |
+
if emitted >= limit:
|
| 884 |
+
break
|
| 885 |
+
instruction = _safe_get(item, ["instruction", "problem", "statement", "question"])
|
| 886 |
+
response = _safe_get(item, ["response", "solution", "answer", "code"])
|
| 887 |
+
if not instruction or not response:
|
| 888 |
+
continue
|
| 889 |
+
if not instruction.lower().startswith("solve the following problem"):
|
| 890 |
+
instruction = f"Solve the following problem:\n{instruction}"
|
| 891 |
+
emitted += 1
|
| 892 |
+
yield build_instruction_sample(
|
| 893 |
+
instruction=instruction,
|
| 894 |
+
response=response,
|
| 895 |
+
source="codeforces_dataset",
|
| 896 |
+
category="problem",
|
| 897 |
)
|
| 898 |
+
|
| 899 |
+
count = _write_jsonl(raw_path, rows())
|
| 900 |
+
logger.info("Loaded Codeforces pre-ingested samples: %d", count)
|
| 901 |
+
return count
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
def build_dataset(args) -> Path:
|
| 905 |
+
ensure_dirs([RAW_DIR, FINAL_DIR, LOG_DIR])
|
| 906 |
+
logger = setup_logger("data_fetch_build", LOG_DIR / "data_fetch.log")
|
| 907 |
+
|
| 908 |
+
logger.info("Starting production dataset build. target_size=%d", args.target_size)
|
| 909 |
+
logger.info("Raw dir: %s", RAW_DIR.resolve())
|
| 910 |
+
logger.info("Final dir: %s", FINAL_DIR.resolve())
|
| 911 |
+
|
| 912 |
+
fetch_plan = {
|
| 913 |
+
"codealpaca": (fetch_instruction_codealpaca, args.codealpaca_limit),
|
| 914 |
+
"evol_instruct_code": (fetch_instruction_evol, args.evol_limit),
|
| 915 |
+
"ultrachat_code": (fetch_instruction_ultrachat_code, args.ultrachat_limit),
|
| 916 |
+
"openhermes_code": (fetch_instruction_openhermes_code, min(args.openhermes_limit, 120_000)),
|
| 917 |
+
"codesearchnet_multilang": (fetch_structured_codesearchnet, args.codesearchnet_limit),
|
| 918 |
+
"github_curated_functions": (fetch_structured_github_functions, args.github_limit),
|
| 919 |
+
"codeforces_problem": (fetch_problem_codeforces, args.codeforces_limit),
|
| 920 |
+
"leetcode_competitive": (fetch_problem_leetcode, args.leetcode_limit),
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
raw_paths: List[Path] = []
|
| 924 |
+
if not args.skip_fetch:
|
| 925 |
+
for name, (fn, limit) in fetch_plan.items():
|
| 926 |
+
raw_path = RAW_DIR / f"{name}.jsonl"
|
| 927 |
+
raw_paths.append(raw_path)
|
| 928 |
+
try:
|
| 929 |
+
count = fn(raw_path, limit, logger)
|
| 930 |
+
logger.info("Fetched %d rows for source=%s", count, name)
|
| 931 |
+
except Exception as exc:
|
| 932 |
+
logger.warning("Skipping source=%s due to fetch error: %s", name, exc)
|
| 933 |
else:
|
| 934 |
+
raw_paths = sorted(RAW_DIR.glob("*.jsonl"))
|
| 935 |
+
logger.info("Skip fetch enabled. Using existing raw files: %d", len(raw_paths))
|
| 936 |
+
|
| 937 |
+
# Phase 1: base balanced build (streaming + dedupe).
|
| 938 |
+
stats = build_balanced_dataset(
|
| 939 |
+
input_paths=raw_paths,
|
| 940 |
+
output_path=FINAL_TRAIN,
|
| 941 |
+
target_size=args.target_size,
|
| 942 |
+
min_tokens=args.min_tokens,
|
| 943 |
+
max_tokens=args.max_tokens,
|
| 944 |
+
num_workers=args.workers,
|
| 945 |
+
category_weights={"instruction": 0.60, "structured": 0.30, "problem": 0.10},
|
| 946 |
+
sqlite_path=FINAL_DIR / "dedupe_hashes.sqlite",
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
# Phase 2: post-build strict rebalance (downsample excess + upsample deficits).
|
| 950 |
+
rebalance_stats = rebalance_final_dataset(
|
| 951 |
+
raw_paths=raw_paths,
|
| 952 |
+
output_path=FINAL_TRAIN,
|
| 953 |
+
target_size=args.target_size,
|
| 954 |
+
min_tokens=args.min_tokens,
|
| 955 |
+
max_tokens=args.max_tokens,
|
| 956 |
+
min_problem_samples=args.min_problem_samples,
|
| 957 |
+
logger=logger,
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
actual_problem = int(rebalance_stats["category_breakdown"].get("problem", 0))
|
| 961 |
+
required_problem = int(args.min_problem_samples)
|
| 962 |
+
real_problem = int(rebalance_stats.get("problem_real_count", 0))
|
| 963 |
+
synthetic_problem = int(rebalance_stats.get("problem_synthetic_count", 0))
|
| 964 |
+
synthetic_ratio = synthetic_problem / max(actual_problem, 1)
|
| 965 |
+
|
| 966 |
+
if actual_problem < max(required_problem, args.min_total_problem_samples):
|
| 967 |
+
raise RuntimeError(
|
| 968 |
+
"Build aborted: insufficient problem-solving data after rebalance. "
|
| 969 |
+
f"Required >= {max(required_problem, args.min_total_problem_samples)}, actual = {actual_problem}. "
|
| 970 |
+
"Increase problem dataset sources (e.g., leetcode/code contests/problem-solution datasets) "
|
| 971 |
+
"or raise problem fetch limits, then rebuild."
|
| 972 |
+
)
|
| 973 |
+
if real_problem < args.min_real_problem_samples:
|
| 974 |
+
raise RuntimeError(
|
| 975 |
+
"Build aborted: insufficient REAL problem-solving data after rebalance. "
|
| 976 |
+
f"Required real >= {args.min_real_problem_samples}, actual real = {real_problem}. "
|
| 977 |
+
"Add more high-quality real problem datasets (APPS/CodeContests/Codeforces/LeetCode)."
|
| 978 |
+
)
|
| 979 |
+
if synthetic_ratio > args.max_synthetic_problem_ratio:
|
| 980 |
+
raise RuntimeError(
|
| 981 |
+
"Build aborted: synthetic problem share too high. "
|
| 982 |
+
f"Allowed <= {args.max_synthetic_problem_ratio:.0%}, actual = {synthetic_ratio:.2%}. "
|
| 983 |
+
"Increase real problem sources and reduce synthetic fallback usage."
|
| 984 |
+
)
|
| 985 |
|
| 986 |
+
logger.info("Build complete. Final dataset: %s", FINAL_TRAIN.resolve())
|
| 987 |
+
logger.info("Base stats: %s", stats)
|
| 988 |
+
logger.info("Rebalanced stats: %s", rebalance_stats)
|
|
|
|
| 989 |
|
| 990 |
+
print(f"Final dataset: {FINAL_TRAIN.resolve()}")
|
| 991 |
+
print(f"Total samples: {rebalance_stats['total_samples']}")
|
| 992 |
+
print(f"Avg length (tokens est.): {rebalance_stats['avg_length_tokens']}")
|
| 993 |
+
print("Per-source breakdown:")
|
| 994 |
+
for src, count in sorted(
|
| 995 |
+
rebalance_stats["source_breakdown"].items(), key=lambda x: x[1], reverse=True
|
| 996 |
+
):
|
| 997 |
+
print(f" - {src}: {count}")
|
| 998 |
+
print("Category breakdown:")
|
| 999 |
+
for cat, count in sorted(rebalance_stats["category_breakdown"].items(), key=lambda x: x[0]):
|
| 1000 |
+
print(f" - {cat}: {count} (target: {rebalance_stats['targets'].get(cat, 0)})")
|
| 1001 |
+
ratio = rebalance_stats["instruction_vs_raw_ratio"]
|
| 1002 |
+
print(
|
| 1003 |
+
f"Instruction vs raw-converted ratio: {ratio['instruction_pct']}% / {ratio['raw_converted_pct']}%"
|
| 1004 |
+
)
|
| 1005 |
+
total = max(1, rebalance_stats["total_samples"])
|
| 1006 |
+
print("Category percentages:")
|
| 1007 |
+
for cat in ("instruction", "structured", "problem"):
|
| 1008 |
+
pct = 100.0 * rebalance_stats["category_breakdown"].get(cat, 0) / total
|
| 1009 |
+
print(f" - {cat}: {pct:.2f}%")
|
| 1010 |
+
print(f"Real problem count: {real_problem}")
|
| 1011 |
+
print(f"Synthetic problem count: {synthetic_problem}")
|
| 1012 |
+
print(f"Synthetic problem %: {synthetic_ratio * 100:.2f}%")
|
| 1013 |
+
return FINAL_TRAIN
|
| 1014 |
|
|
|
|
|
|
|
|
|
|
| 1015 |
|
| 1016 |
+
def _build_parser() -> argparse.ArgumentParser:
|
| 1017 |
+
parser = argparse.ArgumentParser(description="Production-grade coding dataset build pipeline.")
|
| 1018 |
+
parser.add_argument("--build", action="store_true", help="Run the full build pipeline.")
|
| 1019 |
+
parser.add_argument("--target-size", type=int, default=1_000_000)
|
| 1020 |
+
parser.add_argument("--min-tokens", type=int, default=10)
|
| 1021 |
+
parser.add_argument("--max-tokens", type=int, default=2048)
|
| 1022 |
+
parser.add_argument("--skip-fetch", action="store_true", help="Use existing ./data/raw/*.jsonl only.")
|
| 1023 |
+
parser.add_argument(
|
| 1024 |
+
"--workers",
|
| 1025 |
+
type=int,
|
| 1026 |
+
default=max(1, (os.cpu_count() or 4) // 2),
|
| 1027 |
+
help="Parallel worker processes for cleaning stage.",
|
| 1028 |
+
)
|
| 1029 |
|
| 1030 |
+
parser.add_argument("--codealpaca-limit", type=int, default=20000)
|
| 1031 |
+
parser.add_argument("--evol-limit", type=int, default=300000)
|
| 1032 |
+
parser.add_argument("--ultrachat-limit", type=int, default=250000)
|
| 1033 |
+
parser.add_argument("--openhermes-limit", type=int, default=250000)
|
| 1034 |
+
parser.add_argument("--codesearchnet-limit", type=int, default=300000)
|
| 1035 |
+
parser.add_argument("--github-limit", type=int, default=200000)
|
| 1036 |
+
parser.add_argument("--codeforces-limit", type=int, default=200000)
|
| 1037 |
+
parser.add_argument("--leetcode-limit", type=int, default=300000)
|
| 1038 |
+
parser.add_argument(
|
| 1039 |
+
"--stackoverflow-limit",
|
| 1040 |
+
type=int,
|
| 1041 |
+
default=0,
|
| 1042 |
+
help="Deprecated. StackOverflow sources were removed due unreliability.",
|
| 1043 |
+
)
|
| 1044 |
parser.add_argument(
|
| 1045 |
+
"--min-problem-samples",
|
| 1046 |
+
type=int,
|
| 1047 |
+
default=50_000,
|
| 1048 |
+
help="Ensure at least this many samples in problem category during post-rebalance.",
|
| 1049 |
+
)
|
| 1050 |
+
parser.add_argument(
|
| 1051 |
+
"--min-real-problem-samples",
|
| 1052 |
+
type=int,
|
| 1053 |
+
default=50_000,
|
| 1054 |
+
help="Minimum REAL problem samples required after rebalance.",
|
| 1055 |
+
)
|
| 1056 |
+
parser.add_argument(
|
| 1057 |
+
"--min-total-problem-samples",
|
| 1058 |
+
type=int,
|
| 1059 |
+
default=80_000,
|
| 1060 |
+
help="Minimum total problem samples required after rebalance.",
|
| 1061 |
+
)
|
| 1062 |
+
parser.add_argument(
|
| 1063 |
+
"--max-synthetic-problem-ratio",
|
| 1064 |
+
type=float,
|
| 1065 |
+
default=0.30,
|
| 1066 |
+
help="Maximum allowed synthetic (docstring fallback) share in problem category.",
|
| 1067 |
)
|
| 1068 |
return parser
|
| 1069 |
|
| 1070 |
|
| 1071 |
if __name__ == "__main__":
|
| 1072 |
+
parser = _build_parser()
|
| 1073 |
+
args = parser.parse_args()
|
| 1074 |
+
if args.build:
|
| 1075 |
+
build_dataset(args)
|
| 1076 |
+
else:
|
| 1077 |
+
parser.print_help()
|
dataset_cleaner.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import json
|
| 3 |
+
import multiprocessing as mp
|
| 4 |
+
import re
|
| 5 |
+
import sqlite3
|
| 6 |
+
from collections import Counter, defaultdict
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, Iterable, Iterator, List, Optional
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
|
| 14 |
+
CODE_PATTERN = re.compile(
|
| 15 |
+
r"(\bdef\b|\bclass\b|\bimport\b|\breturn\b|=>|function\s+\w+|public\s+class|#include|```)",
|
| 16 |
+
re.IGNORECASE,
|
| 17 |
+
)
|
| 18 |
+
EXPLANATION_PATTERN = re.compile(
|
| 19 |
+
r"\b(explain|because|algorithm|steps|approach|complexity|solution)\b", re.IGNORECASE
|
| 20 |
+
)
|
| 21 |
+
PROBLEM_PROMPT_RE = re.compile(
|
| 22 |
+
r"\b(solve|given|find|compute|return|input|output|problem|algorithm|task|challenge)\b",
|
| 23 |
+
re.IGNORECASE,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def estimate_tokens(text: str) -> int:
|
| 28 |
+
if not text:
|
| 29 |
+
return 0
|
| 30 |
+
return len(TOKEN_PATTERN.findall(text))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def normalize_text(text: str) -> str:
|
| 34 |
+
if text is None:
|
| 35 |
+
return ""
|
| 36 |
+
text = str(text).replace("\x00", "")
|
| 37 |
+
text = text.replace("\r\n", "\n").replace("\r", "\n")
|
| 38 |
+
text = "".join(ch for ch in text if ch == "\n" or ch == "\t" or ord(ch) >= 32)
|
| 39 |
+
lines = [line.rstrip() for line in text.split("\n")]
|
| 40 |
+
return "\n".join(lines).strip()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _ascii_ratio(text: str) -> float:
|
| 44 |
+
if not text:
|
| 45 |
+
return 1.0
|
| 46 |
+
ascii_count = sum(1 for c in text if ord(c) < 128)
|
| 47 |
+
return ascii_count / len(text)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _response_is_valid(response: str) -> bool:
|
| 51 |
+
if not response:
|
| 52 |
+
return False
|
| 53 |
+
if CODE_PATTERN.search(response):
|
| 54 |
+
return True
|
| 55 |
+
if EXPLANATION_PATTERN.search(response):
|
| 56 |
+
return True
|
| 57 |
+
return False
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _response_has_code(response: str) -> bool:
|
| 61 |
+
return bool(
|
| 62 |
+
re.search(
|
| 63 |
+
r"(\bdef\b|\bclass\b|\breturn\b|\bimport\b|```|function\s+\w+|public\s+class|#include|SELECT\s+)",
|
| 64 |
+
response,
|
| 65 |
+
re.IGNORECASE,
|
| 66 |
+
)
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def clean_record(
|
| 71 |
+
record: Dict[str, str],
|
| 72 |
+
*,
|
| 73 |
+
min_tokens: int = 10,
|
| 74 |
+
max_tokens: int = 2048,
|
| 75 |
+
) -> Optional[Dict[str, str]]:
|
| 76 |
+
instruction = normalize_text(record.get("instruction", ""))
|
| 77 |
+
response = normalize_text(record.get("response", ""))
|
| 78 |
+
source = normalize_text(record.get("_source", "unknown"))
|
| 79 |
+
category = normalize_text(record.get("_category", ""))
|
| 80 |
+
if not category:
|
| 81 |
+
src_low = source.lower()
|
| 82 |
+
if any(k in src_low for k in ("codealpaca", "evol", "ultrachat", "openhermes", "orca")):
|
| 83 |
+
category = "instruction"
|
| 84 |
+
elif any(
|
| 85 |
+
k in src_low
|
| 86 |
+
for k in (
|
| 87 |
+
"leetcode",
|
| 88 |
+
"contest",
|
| 89 |
+
"mbpp",
|
| 90 |
+
"humaneval",
|
| 91 |
+
"apps",
|
| 92 |
+
"codeforces",
|
| 93 |
+
"problem",
|
| 94 |
+
"codesearchnet_problem",
|
| 95 |
+
)
|
| 96 |
+
):
|
| 97 |
+
category = "problem"
|
| 98 |
+
else:
|
| 99 |
+
category = "structured"
|
| 100 |
+
|
| 101 |
+
if not instruction or not response:
|
| 102 |
+
return None
|
| 103 |
+
if _ascii_ratio(instruction + response) < 0.85:
|
| 104 |
+
return None
|
| 105 |
+
if not _response_is_valid(response):
|
| 106 |
+
return None
|
| 107 |
+
if category == "problem":
|
| 108 |
+
if len(instruction) <= 50:
|
| 109 |
+
return None
|
| 110 |
+
if not PROBLEM_PROMPT_RE.search(instruction):
|
| 111 |
+
return None
|
| 112 |
+
if not _response_has_code(response):
|
| 113 |
+
return None
|
| 114 |
+
# Problem solutions must include code, not explanation-only text.
|
| 115 |
+
if EXPLANATION_PATTERN.search(response) and not CODE_PATTERN.search(response):
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
total_tokens = estimate_tokens(instruction) + estimate_tokens(response)
|
| 119 |
+
if total_tokens < min_tokens or total_tokens > max_tokens:
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
return {
|
| 123 |
+
"instruction": instruction,
|
| 124 |
+
"response": response,
|
| 125 |
+
"_source": source,
|
| 126 |
+
"_category": category,
|
| 127 |
+
"_tokens": total_tokens,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _iter_jsonl(path: Path) -> Iterable[Dict[str, str]]:
|
| 132 |
+
with path.open("r", encoding="utf-8") as f:
|
| 133 |
+
for line in f:
|
| 134 |
+
line = line.strip()
|
| 135 |
+
if not line:
|
| 136 |
+
continue
|
| 137 |
+
try:
|
| 138 |
+
yield json.loads(line)
|
| 139 |
+
except json.JSONDecodeError:
|
| 140 |
+
continue
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _clean_record_worker(payload: Dict[str, object]) -> Optional[Dict[str, str]]:
|
| 144 |
+
record = payload["record"]
|
| 145 |
+
min_tokens = int(payload["min_tokens"])
|
| 146 |
+
max_tokens = int(payload["max_tokens"])
|
| 147 |
+
return clean_record(record, min_tokens=min_tokens, max_tokens=max_tokens)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def iter_cleaned_records(
|
| 151 |
+
path: Path,
|
| 152 |
+
*,
|
| 153 |
+
min_tokens: int,
|
| 154 |
+
max_tokens: int,
|
| 155 |
+
num_workers: int = 1,
|
| 156 |
+
batch_size: int = 2000,
|
| 157 |
+
) -> Iterator[Dict[str, str]]:
|
| 158 |
+
if num_workers <= 1:
|
| 159 |
+
for record in _iter_jsonl(path):
|
| 160 |
+
cleaned = clean_record(record, min_tokens=min_tokens, max_tokens=max_tokens)
|
| 161 |
+
if cleaned is not None:
|
| 162 |
+
yield cleaned
|
| 163 |
+
return
|
| 164 |
+
|
| 165 |
+
pool = mp.Pool(processes=num_workers)
|
| 166 |
+
try:
|
| 167 |
+
batch: List[Dict[str, str]] = []
|
| 168 |
+
for record in _iter_jsonl(path):
|
| 169 |
+
batch.append(record)
|
| 170 |
+
if len(batch) < batch_size:
|
| 171 |
+
continue
|
| 172 |
+
payloads = [
|
| 173 |
+
{"record": r, "min_tokens": min_tokens, "max_tokens": max_tokens} for r in batch
|
| 174 |
+
]
|
| 175 |
+
for cleaned in pool.imap_unordered(_clean_record_worker, payloads, chunksize=64):
|
| 176 |
+
if cleaned is not None:
|
| 177 |
+
yield cleaned
|
| 178 |
+
batch.clear()
|
| 179 |
+
|
| 180 |
+
if batch:
|
| 181 |
+
payloads = [{"record": r, "min_tokens": min_tokens, "max_tokens": max_tokens} for r in batch]
|
| 182 |
+
for cleaned in pool.imap_unordered(_clean_record_worker, payloads, chunksize=64):
|
| 183 |
+
if cleaned is not None:
|
| 184 |
+
yield cleaned
|
| 185 |
+
finally:
|
| 186 |
+
pool.close()
|
| 187 |
+
pool.join()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _remove_sqlite_artifacts(sqlite_path: Path) -> None:
|
| 191 |
+
if sqlite_path.exists():
|
| 192 |
+
sqlite_path.unlink()
|
| 193 |
+
for suffix in ("-wal", "-shm"):
|
| 194 |
+
p = sqlite_path.with_name(sqlite_path.name + suffix)
|
| 195 |
+
if p.exists():
|
| 196 |
+
p.unlink()
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _open_dedupe_db(sqlite_path: Path):
|
| 200 |
+
sqlite_path = sqlite_path.resolve()
|
| 201 |
+
sqlite_path.parent.mkdir(parents=True, exist_ok=True)
|
| 202 |
+
_remove_sqlite_artifacts(sqlite_path)
|
| 203 |
+
conn = sqlite3.connect(str(sqlite_path))
|
| 204 |
+
conn.execute("PRAGMA journal_mode=WAL;")
|
| 205 |
+
conn.execute("CREATE TABLE IF NOT EXISTS seen_hashes (h TEXT PRIMARY KEY)")
|
| 206 |
+
return conn
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _is_duplicate(conn, instruction: str, response: str) -> bool:
|
| 210 |
+
digest = hashlib.sha256(f"{instruction}||{response}".encode("utf-8")).hexdigest()
|
| 211 |
+
try:
|
| 212 |
+
conn.execute("INSERT INTO seen_hashes(h) VALUES (?)", (digest,))
|
| 213 |
+
return False
|
| 214 |
+
except sqlite3.IntegrityError:
|
| 215 |
+
return True
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def build_balanced_dataset(
|
| 219 |
+
input_paths: List[Path],
|
| 220 |
+
output_path: Path,
|
| 221 |
+
*,
|
| 222 |
+
target_size: int = 1_000_000,
|
| 223 |
+
min_tokens: int = 10,
|
| 224 |
+
max_tokens: int = 2048,
|
| 225 |
+
category_weights: Optional[Dict[str, float]] = None,
|
| 226 |
+
sqlite_path: Optional[Path] = None,
|
| 227 |
+
num_workers: int = 1,
|
| 228 |
+
) -> Dict[str, object]:
|
| 229 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 230 |
+
if sqlite_path is None:
|
| 231 |
+
sqlite_path = output_path.parent / "dedupe_hashes.sqlite"
|
| 232 |
+
conn = _open_dedupe_db(sqlite_path)
|
| 233 |
+
|
| 234 |
+
weights = category_weights or {"instruction": 0.60, "structured": 0.30, "problem": 0.10}
|
| 235 |
+
target_by_cat = {k: int(target_size * v) for k, v in weights.items()}
|
| 236 |
+
target_by_cat["problem"] = target_size - target_by_cat["instruction"] - target_by_cat["structured"]
|
| 237 |
+
|
| 238 |
+
grouped_paths: Dict[str, List[Path]] = defaultdict(list)
|
| 239 |
+
for path in input_paths:
|
| 240 |
+
if not path.exists():
|
| 241 |
+
continue
|
| 242 |
+
name = path.stem
|
| 243 |
+
if "codealpaca" in name or "evol" in name or "ultrachat" in name or "openhermes" in name:
|
| 244 |
+
grouped_paths["instruction"].append(path)
|
| 245 |
+
elif any(
|
| 246 |
+
k in name
|
| 247 |
+
for k in (
|
| 248 |
+
"leetcode",
|
| 249 |
+
"contest",
|
| 250 |
+
"problem",
|
| 251 |
+
"mbpp",
|
| 252 |
+
"humaneval",
|
| 253 |
+
"apps",
|
| 254 |
+
"codeforces",
|
| 255 |
+
)
|
| 256 |
+
):
|
| 257 |
+
grouped_paths["problem"].append(path)
|
| 258 |
+
else:
|
| 259 |
+
grouped_paths["structured"].append(path)
|
| 260 |
+
|
| 261 |
+
source_counter = Counter()
|
| 262 |
+
category_counter = Counter()
|
| 263 |
+
total_tokens = 0
|
| 264 |
+
total_kept = 0
|
| 265 |
+
|
| 266 |
+
def try_write(cleaned: Dict[str, str], out_f, enforce_category_target: bool) -> bool:
|
| 267 |
+
nonlocal total_kept, total_tokens
|
| 268 |
+
category = cleaned["_category"]
|
| 269 |
+
if enforce_category_target and category_counter[category] >= target_by_cat.get(category, 0):
|
| 270 |
+
return False
|
| 271 |
+
if _is_duplicate(conn, cleaned["instruction"], cleaned["response"]):
|
| 272 |
+
return False
|
| 273 |
+
source = cleaned["_source"]
|
| 274 |
+
tokens = int(cleaned["_tokens"])
|
| 275 |
+
category_counter[category] += 1
|
| 276 |
+
source_counter[source] += 1
|
| 277 |
+
total_tokens += tokens
|
| 278 |
+
total_kept += 1
|
| 279 |
+
out_f.write(
|
| 280 |
+
json.dumps(
|
| 281 |
+
{"instruction": cleaned["instruction"], "response": cleaned["response"]},
|
| 282 |
+
ensure_ascii=False,
|
| 283 |
+
)
|
| 284 |
+
+ "\n"
|
| 285 |
+
)
|
| 286 |
+
return True
|
| 287 |
+
|
| 288 |
+
with output_path.open("w", encoding="utf-8") as out_f:
|
| 289 |
+
# Phase 1: enforce 60/30/10 quotas.
|
| 290 |
+
for category in ("instruction", "structured", "problem"):
|
| 291 |
+
if category not in grouped_paths:
|
| 292 |
+
continue
|
| 293 |
+
for path in grouped_paths[category]:
|
| 294 |
+
cleaned_iter = iter_cleaned_records(
|
| 295 |
+
path,
|
| 296 |
+
min_tokens=min_tokens,
|
| 297 |
+
max_tokens=max_tokens,
|
| 298 |
+
num_workers=num_workers,
|
| 299 |
+
)
|
| 300 |
+
for cleaned in tqdm(cleaned_iter, desc=f"balance1:{path.name}", unit="rows"):
|
| 301 |
+
if total_kept >= target_size or category_counter[category] >= target_by_cat[category]:
|
| 302 |
+
break
|
| 303 |
+
try_write(cleaned, out_f, enforce_category_target=True)
|
| 304 |
+
conn.commit()
|
| 305 |
+
if total_kept >= target_size or category_counter[category] >= target_by_cat[category]:
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
# Phase 2: fill remaining slots from all categories while preserving dedupe.
|
| 309 |
+
if total_kept < target_size:
|
| 310 |
+
for path in input_paths:
|
| 311 |
+
if not path.exists():
|
| 312 |
+
continue
|
| 313 |
+
cleaned_iter = iter_cleaned_records(
|
| 314 |
+
path,
|
| 315 |
+
min_tokens=min_tokens,
|
| 316 |
+
max_tokens=max_tokens,
|
| 317 |
+
num_workers=num_workers,
|
| 318 |
+
)
|
| 319 |
+
for cleaned in tqdm(cleaned_iter, desc=f"balance2:{path.name}", unit="rows"):
|
| 320 |
+
if total_kept >= target_size:
|
| 321 |
+
break
|
| 322 |
+
try_write(cleaned, out_f, enforce_category_target=False)
|
| 323 |
+
conn.commit()
|
| 324 |
+
if total_kept >= target_size:
|
| 325 |
+
break
|
| 326 |
+
|
| 327 |
+
conn.close()
|
| 328 |
+
avg_len = round((total_tokens / total_kept), 2) if total_kept else 0.0
|
| 329 |
+
raw_converted = category_counter["structured"] + category_counter["problem"]
|
| 330 |
+
ratio = {
|
| 331 |
+
"instruction_pct": round(100.0 * category_counter["instruction"] / max(total_kept, 1), 2),
|
| 332 |
+
"raw_converted_pct": round(100.0 * raw_converted / max(total_kept, 1), 2),
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
return {
|
| 336 |
+
"total_samples": total_kept,
|
| 337 |
+
"avg_length_tokens": avg_len,
|
| 338 |
+
"source_breakdown": dict(source_counter),
|
| 339 |
+
"category_breakdown": dict(category_counter),
|
| 340 |
+
"instruction_vs_raw_ratio": ratio,
|
| 341 |
+
"targets": target_by_cat,
|
| 342 |
+
}
|
dataset_formatter.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Dict, Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
FUNC_RE = re.compile(r"\bdef\s+([a-zA-Z_]\w*)\s*\(|\bfunction\s+([a-zA-Z_]\w*)\s*\(")
|
| 6 |
+
CLASS_RE = re.compile(r"\bclass\s+([a-zA-Z_]\w*)")
|
| 7 |
+
DOCSTRING_RE = re.compile(r'"""(.*?)"""|\'\'\'(.*?)\'\'\'', re.DOTALL)
|
| 8 |
+
COMMENT_RE = re.compile(r"^\s*(#|//)\s*(.+)$", re.MULTILINE)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def normalize_spaces(text: str) -> str:
|
| 12 |
+
if not text:
|
| 13 |
+
return ""
|
| 14 |
+
return text.replace("\r\n", "\n").replace("\r", "\n").strip()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _first_non_empty(*vals: Optional[str]) -> str:
|
| 18 |
+
for v in vals:
|
| 19 |
+
if v and str(v).strip():
|
| 20 |
+
return str(v).strip()
|
| 21 |
+
return ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def infer_language(lang: str = "", path: str = "") -> str:
|
| 25 |
+
lang = (lang or "").lower()
|
| 26 |
+
path = (path or "").lower()
|
| 27 |
+
if lang:
|
| 28 |
+
return lang
|
| 29 |
+
if path.endswith(".py"):
|
| 30 |
+
return "python"
|
| 31 |
+
if path.endswith(".js"):
|
| 32 |
+
return "javascript"
|
| 33 |
+
if path.endswith(".ts"):
|
| 34 |
+
return "typescript"
|
| 35 |
+
if path.endswith(".java"):
|
| 36 |
+
return "java"
|
| 37 |
+
return "code"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def extract_function_name(code: str) -> str:
|
| 41 |
+
if not code:
|
| 42 |
+
return ""
|
| 43 |
+
m = FUNC_RE.search(code)
|
| 44 |
+
if m:
|
| 45 |
+
return m.group(1) or m.group(2) or ""
|
| 46 |
+
c = CLASS_RE.search(code)
|
| 47 |
+
if c:
|
| 48 |
+
return c.group(1) or ""
|
| 49 |
+
return ""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def extract_doc_or_comment(code: str) -> str:
|
| 53 |
+
if not code:
|
| 54 |
+
return ""
|
| 55 |
+
doc = DOCSTRING_RE.search(code)
|
| 56 |
+
if doc:
|
| 57 |
+
return _first_non_empty(doc.group(1), doc.group(2))
|
| 58 |
+
com = COMMENT_RE.search(code)
|
| 59 |
+
if com:
|
| 60 |
+
return com.group(2).strip()
|
| 61 |
+
return ""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def code_to_instruction(code: str, *, language: str = "", path: str = "", title: str = "") -> str:
|
| 65 |
+
code = normalize_spaces(code)
|
| 66 |
+
lang = infer_language(language, path)
|
| 67 |
+
func = extract_function_name(code)
|
| 68 |
+
hint = _first_non_empty(title, extract_doc_or_comment(code))
|
| 69 |
+
|
| 70 |
+
if func and hint:
|
| 71 |
+
return f"Write a {lang} implementation of `{func}`. Requirements: {hint}"
|
| 72 |
+
if func:
|
| 73 |
+
return f"Write a {lang} function `{func}`."
|
| 74 |
+
if hint:
|
| 75 |
+
return f"Implement this {lang} code task: {hint}"
|
| 76 |
+
if path:
|
| 77 |
+
return f"Implement or refactor the {lang} code from `{path}`."
|
| 78 |
+
return f"Write a correct and production-ready {lang} code snippet."
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def build_instruction_sample(
|
| 82 |
+
*,
|
| 83 |
+
instruction: str = "",
|
| 84 |
+
response: str = "",
|
| 85 |
+
code: str = "",
|
| 86 |
+
language: str = "",
|
| 87 |
+
path: str = "",
|
| 88 |
+
title: str = "",
|
| 89 |
+
source: str,
|
| 90 |
+
category: str,
|
| 91 |
+
) -> Dict[str, str]:
|
| 92 |
+
if not instruction:
|
| 93 |
+
instruction = code_to_instruction(code, language=language, path=path, title=title)
|
| 94 |
+
if not response:
|
| 95 |
+
response = code
|
| 96 |
+
return {
|
| 97 |
+
"instruction": normalize_spaces(instruction),
|
| 98 |
+
"response": normalize_spaces(response),
|
| 99 |
+
"_source": source,
|
| 100 |
+
"_category": category,
|
| 101 |
+
}
|
| 102 |
+
|
final_model/config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MindiForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_mindi.MindiConfig",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_mindi.MindiForCausalLM",
|
| 8 |
+
"AutoTokenizer": [
|
| 9 |
+
null,
|
| 10 |
+
"tokenization_mindi.MindiTokenizer"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
"bos_token_id": 2,
|
| 14 |
+
"d_ff": 4608,
|
| 15 |
+
"d_model": 1152,
|
| 16 |
+
"dropout": 0.1,
|
| 17 |
+
"dtype": "float16",
|
| 18 |
+
"eos_token_id": 3,
|
| 19 |
+
"init_std": 0.02,
|
| 20 |
+
"max_seq_len": 2048,
|
| 21 |
+
"model_type": "mindi",
|
| 22 |
+
"n_heads": 16,
|
| 23 |
+
"n_layers": 23,
|
| 24 |
+
"pad_token_id": 0,
|
| 25 |
+
"rms_norm_eps": 1e-05,
|
| 26 |
+
"tie_embeddings": true,
|
| 27 |
+
"transformers_version": "5.4.0",
|
| 28 |
+
"vocab_size": 50000
|
| 29 |
+
}
|
final_model/configuration_mindi.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face config class for MINDI 1.0 420M.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from transformers import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MindiConfig(PretrainedConfig):
|
| 9 |
+
model_type = "mindi"
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
vocab_size=50000,
|
| 14 |
+
max_seq_len=2048,
|
| 15 |
+
d_model=1152,
|
| 16 |
+
n_layers=23,
|
| 17 |
+
n_heads=16,
|
| 18 |
+
d_ff=4608,
|
| 19 |
+
dropout=0.1,
|
| 20 |
+
tie_embeddings=True,
|
| 21 |
+
init_std=0.02,
|
| 22 |
+
rms_norm_eps=1e-5,
|
| 23 |
+
bos_token_id=2,
|
| 24 |
+
eos_token_id=3,
|
| 25 |
+
pad_token_id=0,
|
| 26 |
+
**kwargs,
|
| 27 |
+
):
|
| 28 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
|
| 29 |
+
self.vocab_size = vocab_size
|
| 30 |
+
self.max_seq_len = max_seq_len
|
| 31 |
+
self.d_model = d_model
|
| 32 |
+
self.n_layers = n_layers
|
| 33 |
+
self.n_heads = n_heads
|
| 34 |
+
self.d_ff = d_ff
|
| 35 |
+
self.dropout = dropout
|
| 36 |
+
self.tie_embeddings = tie_embeddings
|
| 37 |
+
self.init_std = init_std
|
| 38 |
+
self.rms_norm_eps = rms_norm_eps
|
final_model/generation_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 2,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": 3,
|
| 5 |
+
"max_new_tokens": 220,
|
| 6 |
+
"pad_token_id": 0,
|
| 7 |
+
"temperature": 0.2,
|
| 8 |
+
"top_p": 0.9,
|
| 9 |
+
"transformers_version": "5.4.0"
|
| 10 |
+
}
|
backup_step4000.tar.gz → final_model/model.safetensors
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c63f0d3f5cf8fca2fca36c1339b2c07d1c21378ce0753b007e78048607a66764
|
| 3 |
+
size 963088320
|
final_model/modeling_mindi.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face model class for MINDI 1.0 420M.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from transformers import PreTrainedModel
|
| 14 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 15 |
+
|
| 16 |
+
from .configuration_mindi import MindiConfig
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class _Cfg:
|
| 21 |
+
vocab_size: int
|
| 22 |
+
max_seq_len: int
|
| 23 |
+
d_model: int
|
| 24 |
+
n_layers: int
|
| 25 |
+
n_heads: int
|
| 26 |
+
d_ff: int
|
| 27 |
+
dropout: float
|
| 28 |
+
tie_embeddings: bool
|
| 29 |
+
init_std: float
|
| 30 |
+
rms_norm_eps: float
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
def head_dim(self) -> int:
|
| 34 |
+
if self.d_model % self.n_heads != 0:
|
| 35 |
+
raise ValueError("d_model must be divisible by n_heads")
|
| 36 |
+
return self.d_model // self.n_heads
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class RMSNorm(nn.Module):
|
| 40 |
+
def __init__(self, dim: int, eps: float = 1e-5) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.eps = eps
|
| 43 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 44 |
+
|
| 45 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 46 |
+
norm = x.pow(2).mean(dim=-1, keepdim=True)
|
| 47 |
+
x = x * torch.rsqrt(norm + self.eps)
|
| 48 |
+
return self.weight * x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class RotaryEmbedding(nn.Module):
|
| 52 |
+
def __init__(self, head_dim: int, max_seq_len: int) -> None:
|
| 53 |
+
super().__init__()
|
| 54 |
+
if head_dim % 2 != 0:
|
| 55 |
+
raise ValueError("head_dim must be even for rotary embeddings")
|
| 56 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 57 |
+
t = torch.arange(max_seq_len, dtype=torch.float32)
|
| 58 |
+
freqs = torch.outer(t, inv_freq)
|
| 59 |
+
self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
|
| 60 |
+
self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)
|
| 61 |
+
|
| 62 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 63 |
+
cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
|
| 64 |
+
sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
|
| 65 |
+
return self._apply_rotary(q, cos, sin), self._apply_rotary(k, cos, sin)
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
x1 = x[..., ::2]
|
| 70 |
+
x2 = x[..., 1::2]
|
| 71 |
+
xe = x1 * cos - x2 * sin
|
| 72 |
+
xo = x1 * sin + x2 * cos
|
| 73 |
+
return torch.stack((xe, xo), dim=-1).flatten(-2)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class CausalSelfAttention(nn.Module):
|
| 77 |
+
def __init__(self, cfg: _Cfg) -> None:
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.n_heads = cfg.n_heads
|
| 80 |
+
self.head_dim = cfg.head_dim
|
| 81 |
+
self.scale = self.head_dim ** -0.5
|
| 82 |
+
self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 83 |
+
self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 84 |
+
self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 85 |
+
self.o_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 86 |
+
self.dropout = nn.Dropout(cfg.dropout)
|
| 87 |
+
self.rotary = RotaryEmbedding(self.head_dim, cfg.max_seq_len)
|
| 88 |
+
|
| 89 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
bsz, seq_len, _ = x.shape
|
| 91 |
+
q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
|
| 92 |
+
k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
|
| 93 |
+
v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
|
| 94 |
+
q, k = self.rotary(q, k, seq_len=seq_len)
|
| 95 |
+
out = F.scaled_dot_product_attention(
|
| 96 |
+
q,
|
| 97 |
+
k,
|
| 98 |
+
v,
|
| 99 |
+
attn_mask=None,
|
| 100 |
+
dropout_p=self.dropout.p if self.training else 0.0,
|
| 101 |
+
is_causal=True,
|
| 102 |
+
scale=self.scale,
|
| 103 |
+
)
|
| 104 |
+
out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
|
| 105 |
+
return self.o_proj(out)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class FeedForward(nn.Module):
|
| 109 |
+
def __init__(self, cfg: _Cfg) -> None:
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.fc1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 112 |
+
self.fc2 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
|
| 113 |
+
self.dropout = nn.Dropout(cfg.dropout)
|
| 114 |
+
|
| 115 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 116 |
+
x = self.fc1(x)
|
| 117 |
+
x = F.gelu(x, approximate="tanh")
|
| 118 |
+
x = self.fc2(x)
|
| 119 |
+
x = self.dropout(x)
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class TransformerBlock(nn.Module):
|
| 124 |
+
def __init__(self, cfg: _Cfg) -> None:
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.norm1 = RMSNorm(cfg.d_model, cfg.rms_norm_eps)
|
| 127 |
+
self.attn = CausalSelfAttention(cfg)
|
| 128 |
+
self.norm2 = RMSNorm(cfg.d_model, cfg.rms_norm_eps)
|
| 129 |
+
self.ffn = FeedForward(cfg)
|
| 130 |
+
|
| 131 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
x = x + self.attn(self.norm1(x))
|
| 133 |
+
x = x + self.ffn(self.norm2(x))
|
| 134 |
+
return x
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class MindiForCausalLM(PreTrainedModel):
|
| 138 |
+
config_class = MindiConfig
|
| 139 |
+
base_model_prefix = "mindi"
|
| 140 |
+
supports_gradient_checkpointing = False
|
| 141 |
+
|
| 142 |
+
def __init__(self, config: MindiConfig):
|
| 143 |
+
super().__init__(config)
|
| 144 |
+
cfg = _Cfg(
|
| 145 |
+
vocab_size=config.vocab_size,
|
| 146 |
+
max_seq_len=config.max_seq_len,
|
| 147 |
+
d_model=config.d_model,
|
| 148 |
+
n_layers=config.n_layers,
|
| 149 |
+
n_heads=config.n_heads,
|
| 150 |
+
d_ff=config.d_ff,
|
| 151 |
+
dropout=config.dropout,
|
| 152 |
+
tie_embeddings=config.tie_embeddings,
|
| 153 |
+
init_std=config.init_std,
|
| 154 |
+
rms_norm_eps=config.rms_norm_eps,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| 158 |
+
self.dropout = nn.Dropout(cfg.dropout)
|
| 159 |
+
self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
|
| 160 |
+
self.norm_final = RMSNorm(cfg.d_model, cfg.rms_norm_eps)
|
| 161 |
+
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| 162 |
+
|
| 163 |
+
if cfg.tie_embeddings:
|
| 164 |
+
self.lm_head.weight = self.embed_tokens.weight
|
| 165 |
+
|
| 166 |
+
self.post_init()
|
| 167 |
+
|
| 168 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 169 |
+
if isinstance(module, nn.Linear):
|
| 170 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
|
| 171 |
+
elif isinstance(module, nn.Embedding):
|
| 172 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
|
| 173 |
+
|
| 174 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 175 |
+
return self.embed_tokens
|
| 176 |
+
|
| 177 |
+
def set_input_embeddings(self, value: nn.Module) -> None:
|
| 178 |
+
self.embed_tokens = value
|
| 179 |
+
|
| 180 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 181 |
+
return self.lm_head
|
| 182 |
+
|
| 183 |
+
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
|
| 184 |
+
self.lm_head = new_embeddings
|
| 185 |
+
|
| 186 |
+
def forward(
|
| 187 |
+
self,
|
| 188 |
+
input_ids: torch.Tensor,
|
| 189 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 190 |
+
labels: Optional[torch.Tensor] = None,
|
| 191 |
+
**kwargs,
|
| 192 |
+
) -> CausalLMOutputWithPast:
|
| 193 |
+
del attention_mask, kwargs
|
| 194 |
+
|
| 195 |
+
x = self.embed_tokens(input_ids)
|
| 196 |
+
x = self.dropout(x)
|
| 197 |
+
|
| 198 |
+
for block in self.blocks:
|
| 199 |
+
x = block(x)
|
| 200 |
+
|
| 201 |
+
x = self.norm_final(x)
|
| 202 |
+
logits = self.lm_head(x)
|
| 203 |
+
|
| 204 |
+
loss = None
|
| 205 |
+
if labels is not None:
|
| 206 |
+
shift_logits = logits[:, :-1, :].contiguous()
|
| 207 |
+
shift_labels = labels[:, 1:].contiguous()
|
| 208 |
+
loss = F.cross_entropy(
|
| 209 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 210 |
+
shift_labels.view(-1),
|
| 211 |
+
ignore_index=-100,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
return CausalLMOutputWithPast(loss=loss, logits=logits)
|
| 215 |
+
|
| 216 |
+
@torch.no_grad()
|
| 217 |
+
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
|
| 218 |
+
del kwargs
|
| 219 |
+
return {"input_ids": input_ids}
|
final_model/tokenization_mindi.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face tokenizer class for MINDI 1.0 420M.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from transformers import PreTrainedTokenizerFast
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MindiTokenizer(PreTrainedTokenizerFast):
|
| 10 |
+
vocab_files_names = {"tokenizer_file": "tokenizer.json"}
|
| 11 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 12 |
+
|
| 13 |
+
@classmethod
|
| 14 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
|
| 15 |
+
if kwargs.get("tokenizer_file") is None:
|
| 16 |
+
local_candidate = Path(str(pretrained_model_name_or_path)) / "tokenizer.json"
|
| 17 |
+
if local_candidate.exists():
|
| 18 |
+
kwargs["tokenizer_file"] = str(local_candidate)
|
| 19 |
+
return super().from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
|
| 20 |
+
|
| 21 |
+
def __init__(self, tokenizer_file=None, **kwargs):
|
| 22 |
+
name_or_path = kwargs.pop("name_or_path", None)
|
| 23 |
+
if tokenizer_file is None and name_or_path is not None:
|
| 24 |
+
candidate = Path(name_or_path) / "tokenizer.json"
|
| 25 |
+
if candidate.exists():
|
| 26 |
+
tokenizer_file = str(candidate)
|
| 27 |
+
if tokenizer_file is None:
|
| 28 |
+
tokenizer_file = str(Path(__file__).resolve().parent / "tokenizer.json")
|
| 29 |
+
kwargs.setdefault("bos_token", "<BOS>")
|
| 30 |
+
kwargs.setdefault("eos_token", "<EOS>")
|
| 31 |
+
kwargs.setdefault("unk_token", "<UNK>")
|
| 32 |
+
kwargs.setdefault("pad_token", "<PAD>")
|
| 33 |
+
super().__init__(tokenizer_file=tokenizer_file, **kwargs)
|
final_model/tokenizer.json
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "1.0",
|
| 3 |
+
"truncation": null,
|
| 4 |
+
"padding": null,
|
| 5 |
+
"added_tokens": [
|
| 6 |
+
{
|
| 7 |
+
"id": 0,
|
| 8 |
+
"content": "<PAD>",
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"lstrip": false,
|
| 11 |
+
"rstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"special": true
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"id": 1,
|
| 17 |
+
"content": "<UNK>",
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"normalized": false,
|
| 22 |
+
"special": true
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"id": 2,
|
| 26 |
+
"content": "<BOS>",
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"rstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"special": true
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"id": 3,
|
| 35 |
+
"content": "<EOS>",
|
| 36 |
+
"single_word": false,
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"rstrip": false,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"special": true
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"id": 4,
|
| 44 |
+
"content": "<NL>",
|
| 45 |
+
"single_word": false,
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"special": true
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"id": 5,
|
| 53 |
+
"content": "<INDENT>",
|
| 54 |
+
"single_word": false,
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"rstrip": false,
|
| 57 |
+
"normalized": false,
|
| 58 |
+
"special": true
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"id": 6,
|
| 62 |
+
"content": "<DEDENT>",
|
| 63 |
+
"single_word": false,
|
| 64 |
+
"lstrip": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"normalized": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"id": 7,
|
| 71 |
+
"content": "<PROMPT>",
|
| 72 |
+
"single_word": false,
|
| 73 |
+
"lstrip": false,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"normalized": false,
|
| 76 |
+
"special": true
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"id": 8,
|
| 80 |
+
"content": "<CODE>",
|
| 81 |
+
"single_word": false,
|
| 82 |
+
"lstrip": false,
|
| 83 |
+
"rstrip": false,
|
| 84 |
+
"normalized": false,
|
| 85 |
+
"special": true
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"id": 9,
|
| 89 |
+
"content": "<PYTHON>",
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"lstrip": false,
|
| 92 |
+
"rstrip": false,
|
| 93 |
+
"normalized": false,
|
| 94 |
+
"special": true
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"id": 10,
|
| 98 |
+
"content": "<JAVASCRIPT>",
|
| 99 |
+
"single_word": false,
|
| 100 |
+
"lstrip": false,
|
| 101 |
+
"rstrip": false,
|
| 102 |
+
"normalized": false,
|
| 103 |
+
"special": true
|
| 104 |
+
}
|
| 105 |
+
],
|
| 106 |
+
"normalizer": {
|
| 107 |
+
"type": "Sequence",
|
| 108 |
+
"normalizers": [
|
| 109 |
+
{
|
| 110 |
+
"type": "NFKC"
|
| 111 |
+
}
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
"pre_tokenizer": {
|
| 115 |
+
"type": "Sequence",
|
| 116 |
+
"pretokenizers": [
|
| 117 |
+
{
|
| 118 |
+
"type": "Split",
|
| 119 |
+
"pattern": {
|
| 120 |
+
"Regex": "(==|!=|<=|>=|:=|->|=>|\\+\\+|--|\\+=|-=|\\*=|/=|//=|%=|\\*\\*|&&|\\|\\||<<|>>)"
|
| 121 |
+
},
|
| 122 |
+
"behavior": "Isolated",
|
| 123 |
+
"invert": false
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"type": "Split",
|
| 127 |
+
"pattern": {
|
| 128 |
+
"Regex": "([()\\[\\]{}.,:;])"
|
| 129 |
+
},
|
| 130 |
+
"behavior": "Isolated",
|
| 131 |
+
"invert": false
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"type": "Metaspace",
|
| 135 |
+
"replacement": "_",
|
| 136 |
+
"prepend_scheme": "always",
|
| 137 |
+
"split": true
|
| 138 |
+
}
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
"post_processor": {
|
| 142 |
+
"type": "TemplateProcessing",
|
| 143 |
+
"single": [
|
| 144 |
+
{
|
| 145 |
+
"SpecialToken": {
|
| 146 |
+
"id": "<BOS>",
|
| 147 |
+
"type_id": 0
|
| 148 |
+
}
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"Sequence": {
|
| 152 |
+
"id": "A",
|
| 153 |
+
"type_id": 0
|
| 154 |
+
}
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"SpecialToken": {
|
| 158 |
+
"id": "<EOS>",
|
| 159 |
+
"type_id": 0
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
],
|
| 163 |
+
"pair": [
|
| 164 |
+
{
|
| 165 |
+
"Sequence": {
|
| 166 |
+
"id": "A",
|
| 167 |
+
"type_id": 0
|
| 168 |
+
}
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"Sequence": {
|
| 172 |
+
"id": "B",
|
| 173 |
+
"type_id": 1
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
],
|
| 177 |
+
"special_tokens": {
|
| 178 |
+
"<BOS>": {
|
| 179 |
+
"id": "<BOS>",
|
| 180 |
+
"ids": [
|
| 181 |
+
2
|
| 182 |
+
],
|
| 183 |
+
"tokens": [
|
| 184 |
+
"<BOS>"
|
| 185 |
+
]
|
| 186 |
+
},
|
| 187 |
+
"<EOS>": {
|
| 188 |
+
"id": "<EOS>",
|
| 189 |
+
"ids": [
|
| 190 |
+
3
|
| 191 |
+
],
|
| 192 |
+
"tokens": [
|
| 193 |
+
"<EOS>"
|
| 194 |
+
]
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
},
|
| 198 |
+
"decoder": {
|
| 199 |
+
"type": "BPEDecoder",
|
| 200 |
+
"suffix": "</w>"
|
| 201 |
+
},
|
| 202 |
+
"model": {
|
| 203 |
+
"type": "BPE",
|
| 204 |
+
"dropout": null,
|
| 205 |
+
"unk_token": "<UNK>",
|
| 206 |
+
"continuing_subword_prefix": null,
|
| 207 |
+
"end_of_word_suffix": null,
|
| 208 |
+
"fuse_unk": false,
|
| 209 |
+
"byte_fallback": false,
|
| 210 |
+
"ignore_merges": false,
|
| 211 |
+
"vocab": {
|
| 212 |
+
"<PAD>": 0,
|
| 213 |
+
"<UNK>": 1,
|
| 214 |
+
"<BOS>": 2,
|
| 215 |
+
"<EOS>": 3,
|
| 216 |
+
"<NL>": 4,
|
| 217 |
+
"<INDENT>": 5,
|
| 218 |
+
"<DEDENT>": 6,
|
| 219 |
+
"<PROMPT>": 7,
|
| 220 |
+
"<CODE>": 8,
|
| 221 |
+
"<PYTHON>": 9,
|
| 222 |
+
"<JAVASCRIPT>": 10,
|
| 223 |
+
"(": 11,
|
| 224 |
+
")": 12,
|
| 225 |
+
"+": 13,
|
| 226 |
+
",": 14,
|
| 227 |
+
".": 15,
|
| 228 |
+
"0": 16,
|
| 229 |
+
"4": 17,
|
| 230 |
+
"5": 18,
|
| 231 |
+
":": 19,
|
| 232 |
+
";": 20,
|
| 233 |
+
"<": 21,
|
| 234 |
+
"=": 22,
|
| 235 |
+
">": 23,
|
| 236 |
+
"A": 24,
|
| 237 |
+
"C": 25,
|
| 238 |
+
"D": 26,
|
| 239 |
+
"E": 27,
|
| 240 |
+
"F": 28,
|
| 241 |
+
"H": 29,
|
| 242 |
+
"I": 30,
|
| 243 |
+
"J": 31,
|
| 244 |
+
"L": 32,
|
| 245 |
+
"M": 33,
|
| 246 |
+
"N": 34,
|
| 247 |
+
"O": 35,
|
| 248 |
+
"P": 36,
|
| 249 |
+
"R": 37,
|
| 250 |
+
"S": 38,
|
| 251 |
+
"T": 39,
|
| 252 |
+
"V": 40,
|
| 253 |
+
"W": 41,
|
| 254 |
+
"Y": 42,
|
| 255 |
+
"_": 43,
|
| 256 |
+
"a": 44,
|
| 257 |
+
"b": 45,
|
| 258 |
+
"c": 46,
|
| 259 |
+
"d": 47,
|
| 260 |
+
"e": 48,
|
| 261 |
+
"f": 49,
|
| 262 |
+
"g": 50,
|
| 263 |
+
"h": 51,
|
| 264 |
+
"i": 52,
|
| 265 |
+
"l": 53,
|
| 266 |
+
"m": 54,
|
| 267 |
+
"n": 55,
|
| 268 |
+
"o": 56,
|
| 269 |
+
"p": 57,
|
| 270 |
+
"r": 58,
|
| 271 |
+
"s": 59,
|
| 272 |
+
"t": 60,
|
| 273 |
+
"u": 61,
|
| 274 |
+
"v": 62,
|
| 275 |
+
"w": 63,
|
| 276 |
+
"x": 64,
|
| 277 |
+
"y": 65,
|
| 278 |
+
"{": 66,
|
| 279 |
+
"}": 67,
|
| 280 |
+
"_<": 68,
|
| 281 |
+
"DE": 69,
|
| 282 |
+
"T>": 70,
|
| 283 |
+
"_a": 71,
|
| 284 |
+
"L>": 72,
|
| 285 |
+
"NL>": 73,
|
| 286 |
+
"_<NL>": 74,
|
| 287 |
+
"NT>": 75,
|
| 288 |
+
"_t": 76,
|
| 289 |
+
"DENT>": 77,
|
| 290 |
+
"_i": 78,
|
| 291 |
+
"PT>": 79,
|
| 292 |
+
"_(": 80,
|
| 293 |
+
"_)": 81,
|
| 294 |
+
"on": 82,
|
| 295 |
+
"_<P": 83,
|
| 296 |
+
"_f": 84,
|
| 297 |
+
"_l": 85,
|
| 298 |
+
"re": 86,
|
| 299 |
+
"ri": 87,
|
| 300 |
+
"CO": 88,
|
| 301 |
+
"IN": 89,
|
| 302 |
+
"MPT>": 90,
|
| 303 |
+
"OMPT>": 91,
|
| 304 |
+
"ROMPT>": 92,
|
| 305 |
+
"_;": 93,
|
| 306 |
+
"_b": 94,
|
| 307 |
+
"at": 95,
|
| 308 |
+
"_<DE": 96,
|
| 309 |
+
"_<CO": 97,
|
| 310 |
+
"_<IN": 98,
|
| 311 |
+
"DE>": 99,
|
| 312 |
+
"_to": 100,
|
| 313 |
+
"_<PROMPT>": 101,
|
| 314 |
+
"_lo": 102,
|
| 315 |
+
"_<DEDENT>": 103,
|
| 316 |
+
"_<CODE>": 104,
|
| 317 |
+
"_<INDENT>": 105,
|
| 318 |
+
"_+": 106,
|
| 319 |
+
"_0": 107,
|
| 320 |
+
"_re": 108,
|
| 321 |
+
"ct": 109,
|
| 322 |
+
"dd": 110,
|
| 323 |
+
"ion": 111,
|
| 324 |
+
"nct": 112,
|
| 325 |
+
"rn": 113,
|
| 326 |
+
"tu": 114,
|
| 327 |
+
"unct": 115,
|
| 328 |
+
"va": 116,
|
| 329 |
+
"_add": 117,
|
| 330 |
+
"_th": 118,
|
| 331 |
+
"_funct": 119,
|
| 332 |
+
"_retu": 120,
|
| 333 |
+
"_function": 121,
|
| 334 |
+
"_return": 122,
|
| 335 |
+
"AS": 123,
|
| 336 |
+
"AV": 124,
|
| 337 |
+
"CR": 125,
|
| 338 |
+
"Cre": 126,
|
| 339 |
+
"HO": 127,
|
| 340 |
+
"IPT>": 128,
|
| 341 |
+
"Ja": 129,
|
| 342 |
+
"JAV": 130,
|
| 343 |
+
"N>": 131,
|
| 344 |
+
"Py": 132,
|
| 345 |
+
"Sc": 133,
|
| 346 |
+
"THO": 134,
|
| 347 |
+
"YTHO": 135,
|
| 348 |
+
"_,": 136,
|
| 349 |
+
"_4": 137,
|
| 350 |
+
"_5": 138,
|
| 351 |
+
"_:": 139,
|
| 352 |
+
"_p": 140,
|
| 353 |
+
"_{": 141,
|
| 354 |
+
"_}": 142,
|
| 355 |
+
"_Cre": 143,
|
| 356 |
+
"_Ja": 144,
|
| 357 |
+
"_Py": 145,
|
| 358 |
+
"hon": 146,
|
| 359 |
+
"nt": 147,
|
| 360 |
+
"op": 148,
|
| 361 |
+
"or": 149,
|
| 362 |
+
"pt": 150,
|
| 363 |
+
"thon": 151,
|
| 364 |
+
"_<JAV": 152,
|
| 365 |
+
"_<PYTHO": 153,
|
| 366 |
+
"_for": 154,
|
| 367 |
+
"rint": 155,
|
| 368 |
+
"ript": 156,
|
| 369 |
+
"ate": 157,
|
| 370 |
+
"_log": 158,
|
| 371 |
+
"_loop": 159,
|
| 372 |
+
"vaSc": 160,
|
| 373 |
+
"_that": 161,
|
| 374 |
+
"ASCR": 162,
|
| 375 |
+
"_print": 163,
|
| 376 |
+
"_Create": 164,
|
| 377 |
+
"_JavaSc": 165,
|
| 378 |
+
"_Python": 166,
|
| 379 |
+
"_<JAVASCR": 167,
|
| 380 |
+
"_<PYTHON>": 168,
|
| 381 |
+
"_JavaScript": 169,
|
| 382 |
+
"_<JAVASCRIPT>": 170
|
| 383 |
+
},
|
| 384 |
+
"merges": [
|
| 385 |
+
[
|
| 386 |
+
"_",
|
| 387 |
+
"<"
|
| 388 |
+
],
|
| 389 |
+
[
|
| 390 |
+
"D",
|
| 391 |
+
"E"
|
| 392 |
+
],
|
| 393 |
+
[
|
| 394 |
+
"T",
|
| 395 |
+
">"
|
| 396 |
+
],
|
| 397 |
+
[
|
| 398 |
+
"_",
|
| 399 |
+
"a"
|
| 400 |
+
],
|
| 401 |
+
[
|
| 402 |
+
"L",
|
| 403 |
+
">"
|
| 404 |
+
],
|
| 405 |
+
[
|
| 406 |
+
"N",
|
| 407 |
+
"L>"
|
| 408 |
+
],
|
| 409 |
+
[
|
| 410 |
+
"_<",
|
| 411 |
+
"NL>"
|
| 412 |
+
],
|
| 413 |
+
[
|
| 414 |
+
"N",
|
| 415 |
+
"T>"
|
| 416 |
+
],
|
| 417 |
+
[
|
| 418 |
+
"_",
|
| 419 |
+
"t"
|
| 420 |
+
],
|
| 421 |
+
[
|
| 422 |
+
"DE",
|
| 423 |
+
"NT>"
|
| 424 |
+
],
|
| 425 |
+
[
|
| 426 |
+
"_",
|
| 427 |
+
"i"
|
| 428 |
+
],
|
| 429 |
+
[
|
| 430 |
+
"P",
|
| 431 |
+
"T>"
|
| 432 |
+
],
|
| 433 |
+
[
|
| 434 |
+
"_",
|
| 435 |
+
"("
|
| 436 |
+
],
|
| 437 |
+
[
|
| 438 |
+
"_",
|
| 439 |
+
")"
|
| 440 |
+
],
|
| 441 |
+
[
|
| 442 |
+
"o",
|
| 443 |
+
"n"
|
| 444 |
+
],
|
| 445 |
+
[
|
| 446 |
+
"_<",
|
| 447 |
+
"P"
|
| 448 |
+
],
|
| 449 |
+
[
|
| 450 |
+
"_",
|
| 451 |
+
"f"
|
| 452 |
+
],
|
| 453 |
+
[
|
| 454 |
+
"_",
|
| 455 |
+
"l"
|
| 456 |
+
],
|
| 457 |
+
[
|
| 458 |
+
"r",
|
| 459 |
+
"e"
|
| 460 |
+
],
|
| 461 |
+
[
|
| 462 |
+
"r",
|
| 463 |
+
"i"
|
| 464 |
+
],
|
| 465 |
+
[
|
| 466 |
+
"C",
|
| 467 |
+
"O"
|
| 468 |
+
],
|
| 469 |
+
[
|
| 470 |
+
"I",
|
| 471 |
+
"N"
|
| 472 |
+
],
|
| 473 |
+
[
|
| 474 |
+
"M",
|
| 475 |
+
"PT>"
|
| 476 |
+
],
|
| 477 |
+
[
|
| 478 |
+
"O",
|
| 479 |
+
"MPT>"
|
| 480 |
+
],
|
| 481 |
+
[
|
| 482 |
+
"R",
|
| 483 |
+
"OMPT>"
|
| 484 |
+
],
|
| 485 |
+
[
|
| 486 |
+
"_",
|
| 487 |
+
";"
|
| 488 |
+
],
|
| 489 |
+
[
|
| 490 |
+
"_",
|
| 491 |
+
"b"
|
| 492 |
+
],
|
| 493 |
+
[
|
| 494 |
+
"a",
|
| 495 |
+
"t"
|
| 496 |
+
],
|
| 497 |
+
[
|
| 498 |
+
"_<",
|
| 499 |
+
"DE"
|
| 500 |
+
],
|
| 501 |
+
[
|
| 502 |
+
"_<",
|
| 503 |
+
"CO"
|
| 504 |
+
],
|
| 505 |
+
[
|
| 506 |
+
"_<",
|
| 507 |
+
"IN"
|
| 508 |
+
],
|
| 509 |
+
[
|
| 510 |
+
"DE",
|
| 511 |
+
">"
|
| 512 |
+
],
|
| 513 |
+
[
|
| 514 |
+
"_t",
|
| 515 |
+
"o"
|
| 516 |
+
],
|
| 517 |
+
[
|
| 518 |
+
"_<P",
|
| 519 |
+
"ROMPT>"
|
| 520 |
+
],
|
| 521 |
+
[
|
| 522 |
+
"_l",
|
| 523 |
+
"o"
|
| 524 |
+
],
|
| 525 |
+
[
|
| 526 |
+
"_<DE",
|
| 527 |
+
"DENT>"
|
| 528 |
+
],
|
| 529 |
+
[
|
| 530 |
+
"_<CO",
|
| 531 |
+
"DE>"
|
| 532 |
+
],
|
| 533 |
+
[
|
| 534 |
+
"_<IN",
|
| 535 |
+
"DENT>"
|
| 536 |
+
],
|
| 537 |
+
[
|
| 538 |
+
"_",
|
| 539 |
+
"+"
|
| 540 |
+
],
|
| 541 |
+
[
|
| 542 |
+
"_",
|
| 543 |
+
"0"
|
| 544 |
+
],
|
| 545 |
+
[
|
| 546 |
+
"_",
|
| 547 |
+
"re"
|
| 548 |
+
],
|
| 549 |
+
[
|
| 550 |
+
"c",
|
| 551 |
+
"t"
|
| 552 |
+
],
|
| 553 |
+
[
|
| 554 |
+
"d",
|
| 555 |
+
"d"
|
| 556 |
+
],
|
| 557 |
+
[
|
| 558 |
+
"i",
|
| 559 |
+
"on"
|
| 560 |
+
],
|
| 561 |
+
[
|
| 562 |
+
"n",
|
| 563 |
+
"ct"
|
| 564 |
+
],
|
| 565 |
+
[
|
| 566 |
+
"r",
|
| 567 |
+
"n"
|
| 568 |
+
],
|
| 569 |
+
[
|
| 570 |
+
"t",
|
| 571 |
+
"u"
|
| 572 |
+
],
|
| 573 |
+
[
|
| 574 |
+
"u",
|
| 575 |
+
"nct"
|
| 576 |
+
],
|
| 577 |
+
[
|
| 578 |
+
"v",
|
| 579 |
+
"a"
|
| 580 |
+
],
|
| 581 |
+
[
|
| 582 |
+
"_a",
|
| 583 |
+
"dd"
|
| 584 |
+
],
|
| 585 |
+
[
|
| 586 |
+
"_t",
|
| 587 |
+
"h"
|
| 588 |
+
],
|
| 589 |
+
[
|
| 590 |
+
"_f",
|
| 591 |
+
"unct"
|
| 592 |
+
],
|
| 593 |
+
[
|
| 594 |
+
"_re",
|
| 595 |
+
"tu"
|
| 596 |
+
],
|
| 597 |
+
[
|
| 598 |
+
"_funct",
|
| 599 |
+
"ion"
|
| 600 |
+
],
|
| 601 |
+
[
|
| 602 |
+
"_retu",
|
| 603 |
+
"rn"
|
| 604 |
+
],
|
| 605 |
+
[
|
| 606 |
+
"A",
|
| 607 |
+
"S"
|
| 608 |
+
],
|
| 609 |
+
[
|
| 610 |
+
"A",
|
| 611 |
+
"V"
|
| 612 |
+
],
|
| 613 |
+
[
|
| 614 |
+
"C",
|
| 615 |
+
"R"
|
| 616 |
+
],
|
| 617 |
+
[
|
| 618 |
+
"C",
|
| 619 |
+
"re"
|
| 620 |
+
],
|
| 621 |
+
[
|
| 622 |
+
"H",
|
| 623 |
+
"O"
|
| 624 |
+
],
|
| 625 |
+
[
|
| 626 |
+
"I",
|
| 627 |
+
"PT>"
|
| 628 |
+
],
|
| 629 |
+
[
|
| 630 |
+
"J",
|
| 631 |
+
"a"
|
| 632 |
+
],
|
| 633 |
+
[
|
| 634 |
+
"J",
|
| 635 |
+
"AV"
|
| 636 |
+
],
|
| 637 |
+
[
|
| 638 |
+
"N",
|
| 639 |
+
">"
|
| 640 |
+
],
|
| 641 |
+
[
|
| 642 |
+
"P",
|
| 643 |
+
"y"
|
| 644 |
+
],
|
| 645 |
+
[
|
| 646 |
+
"S",
|
| 647 |
+
"c"
|
| 648 |
+
],
|
| 649 |
+
[
|
| 650 |
+
"T",
|
| 651 |
+
"HO"
|
| 652 |
+
],
|
| 653 |
+
[
|
| 654 |
+
"Y",
|
| 655 |
+
"THO"
|
| 656 |
+
],
|
| 657 |
+
[
|
| 658 |
+
"_",
|
| 659 |
+
","
|
| 660 |
+
],
|
| 661 |
+
[
|
| 662 |
+
"_",
|
| 663 |
+
"4"
|
| 664 |
+
],
|
| 665 |
+
[
|
| 666 |
+
"_",
|
| 667 |
+
"5"
|
| 668 |
+
],
|
| 669 |
+
[
|
| 670 |
+
"_",
|
| 671 |
+
":"
|
| 672 |
+
],
|
| 673 |
+
[
|
| 674 |
+
"_",
|
| 675 |
+
"p"
|
| 676 |
+
],
|
| 677 |
+
[
|
| 678 |
+
"_",
|
| 679 |
+
"{"
|
| 680 |
+
],
|
| 681 |
+
[
|
| 682 |
+
"_",
|
| 683 |
+
"}"
|
| 684 |
+
],
|
| 685 |
+
[
|
| 686 |
+
"_",
|
| 687 |
+
"Cre"
|
| 688 |
+
],
|
| 689 |
+
[
|
| 690 |
+
"_",
|
| 691 |
+
"Ja"
|
| 692 |
+
],
|
| 693 |
+
[
|
| 694 |
+
"_",
|
| 695 |
+
"Py"
|
| 696 |
+
],
|
| 697 |
+
[
|
| 698 |
+
"h",
|
| 699 |
+
"on"
|
| 700 |
+
],
|
| 701 |
+
[
|
| 702 |
+
"n",
|
| 703 |
+
"t"
|
| 704 |
+
],
|
| 705 |
+
[
|
| 706 |
+
"o",
|
| 707 |
+
"p"
|
| 708 |
+
],
|
| 709 |
+
[
|
| 710 |
+
"o",
|
| 711 |
+
"r"
|
| 712 |
+
],
|
| 713 |
+
[
|
| 714 |
+
"p",
|
| 715 |
+
"t"
|
| 716 |
+
],
|
| 717 |
+
[
|
| 718 |
+
"t",
|
| 719 |
+
"hon"
|
| 720 |
+
],
|
| 721 |
+
[
|
| 722 |
+
"_<",
|
| 723 |
+
"JAV"
|
| 724 |
+
],
|
| 725 |
+
[
|
| 726 |
+
"_<P",
|
| 727 |
+
"YTHO"
|
| 728 |
+
],
|
| 729 |
+
[
|
| 730 |
+
"_f",
|
| 731 |
+
"or"
|
| 732 |
+
],
|
| 733 |
+
[
|
| 734 |
+
"ri",
|
| 735 |
+
"nt"
|
| 736 |
+
],
|
| 737 |
+
[
|
| 738 |
+
"ri",
|
| 739 |
+
"pt"
|
| 740 |
+
],
|
| 741 |
+
[
|
| 742 |
+
"at",
|
| 743 |
+
"e"
|
| 744 |
+
],
|
| 745 |
+
[
|
| 746 |
+
"_lo",
|
| 747 |
+
"g"
|
| 748 |
+
],
|
| 749 |
+
[
|
| 750 |
+
"_lo",
|
| 751 |
+
"op"
|
| 752 |
+
],
|
| 753 |
+
[
|
| 754 |
+
"va",
|
| 755 |
+
"Sc"
|
| 756 |
+
],
|
| 757 |
+
[
|
| 758 |
+
"_th",
|
| 759 |
+
"at"
|
| 760 |
+
],
|
| 761 |
+
[
|
| 762 |
+
"AS",
|
| 763 |
+
"CR"
|
| 764 |
+
],
|
| 765 |
+
[
|
| 766 |
+
"_p",
|
| 767 |
+
"rint"
|
| 768 |
+
],
|
| 769 |
+
[
|
| 770 |
+
"_Cre",
|
| 771 |
+
"ate"
|
| 772 |
+
],
|
| 773 |
+
[
|
| 774 |
+
"_Ja",
|
| 775 |
+
"vaSc"
|
| 776 |
+
],
|
| 777 |
+
[
|
| 778 |
+
"_Py",
|
| 779 |
+
"thon"
|
| 780 |
+
],
|
| 781 |
+
[
|
| 782 |
+
"_<JAV",
|
| 783 |
+
"ASCR"
|
| 784 |
+
],
|
| 785 |
+
[
|
| 786 |
+
"_<PYTHO",
|
| 787 |
+
"N>"
|
| 788 |
+
],
|
| 789 |
+
[
|
| 790 |
+
"_JavaSc",
|
| 791 |
+
"ript"
|
| 792 |
+
],
|
| 793 |
+
[
|
| 794 |
+
"_<JAVASCR",
|
| 795 |
+
"IPT>"
|
| 796 |
+
]
|
| 797 |
+
]
|
| 798 |
+
}
|
| 799 |
+
}
|
final_model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoTokenizer": [
|
| 4 |
+
"tokenization_mindi.MindiTokenizer",
|
| 5 |
+
null
|
| 6 |
+
]
|
| 7 |
+
},
|
| 8 |
+
"backend": "tokenizers",
|
| 9 |
+
"bos_token": "<BOS>",
|
| 10 |
+
"eos_token": "<EOS>",
|
| 11 |
+
"is_local": true,
|
| 12 |
+
"model_max_length": 2048,
|
| 13 |
+
"pad_token": "<PAD>",
|
| 14 |
+
"padding_side": "right",
|
| 15 |
+
"tokenizer_class": "MindiTokenizer",
|
| 16 |
+
"truncation_side": "right",
|
| 17 |
+
"unk_token": "<UNK>",
|
| 18 |
+
"vocab": {
|
| 19 |
+
"(": 11,
|
| 20 |
+
")": 12,
|
| 21 |
+
"+": 13,
|
| 22 |
+
",": 14,
|
| 23 |
+
".": 15,
|
| 24 |
+
"0": 16,
|
| 25 |
+
"4": 17,
|
| 26 |
+
"5": 18,
|
| 27 |
+
":": 19,
|
| 28 |
+
";": 20,
|
| 29 |
+
"<": 21,
|
| 30 |
+
"<BOS>": 2,
|
| 31 |
+
"<CODE>": 8,
|
| 32 |
+
"<DEDENT>": 6,
|
| 33 |
+
"<EOS>": 3,
|
| 34 |
+
"<INDENT>": 5,
|
| 35 |
+
"<JAVASCRIPT>": 10,
|
| 36 |
+
"<NL>": 4,
|
| 37 |
+
"<PAD>": 0,
|
| 38 |
+
"<PROMPT>": 7,
|
| 39 |
+
"<PYTHON>": 9,
|
| 40 |
+
"<UNK>": 1,
|
| 41 |
+
"=": 22,
|
| 42 |
+
">": 23,
|
| 43 |
+
"A": 24,
|
| 44 |
+
"AS": 123,
|
| 45 |
+
"ASCR": 162,
|
| 46 |
+
"AV": 124,
|
| 47 |
+
"C": 25,
|
| 48 |
+
"CO": 88,
|
| 49 |
+
"CR": 125,
|
| 50 |
+
"Cre": 126,
|
| 51 |
+
"D": 26,
|
| 52 |
+
"DE": 69,
|
| 53 |
+
"DE>": 99,
|
| 54 |
+
"DENT>": 77,
|
| 55 |
+
"E": 27,
|
| 56 |
+
"F": 28,
|
| 57 |
+
"H": 29,
|
| 58 |
+
"HO": 127,
|
| 59 |
+
"I": 30,
|
| 60 |
+
"IN": 89,
|
| 61 |
+
"IPT>": 128,
|
| 62 |
+
"J": 31,
|
| 63 |
+
"JAV": 130,
|
| 64 |
+
"Ja": 129,
|
| 65 |
+
"L": 32,
|
| 66 |
+
"L>": 72,
|
| 67 |
+
"M": 33,
|
| 68 |
+
"MPT>": 90,
|
| 69 |
+
"N": 34,
|
| 70 |
+
"N>": 131,
|
| 71 |
+
"NL>": 73,
|
| 72 |
+
"NT>": 75,
|
| 73 |
+
"O": 35,
|
| 74 |
+
"OMPT>": 91,
|
| 75 |
+
"P": 36,
|
| 76 |
+
"PT>": 79,
|
| 77 |
+
"Py": 132,
|
| 78 |
+
"R": 37,
|
| 79 |
+
"ROMPT>": 92,
|
| 80 |
+
"S": 38,
|
| 81 |
+
"Sc": 133,
|
| 82 |
+
"T": 39,
|
| 83 |
+
"T>": 70,
|
| 84 |
+
"THO": 134,
|
| 85 |
+
"V": 40,
|
| 86 |
+
"W": 41,
|
| 87 |
+
"Y": 42,
|
| 88 |
+
"YTHO": 135,
|
| 89 |
+
"_": 43,
|
| 90 |
+
"_(": 80,
|
| 91 |
+
"_)": 81,
|
| 92 |
+
"_+": 106,
|
| 93 |
+
"_,": 136,
|
| 94 |
+
"_0": 107,
|
| 95 |
+
"_4": 137,
|
| 96 |
+
"_5": 138,
|
| 97 |
+
"_:": 139,
|
| 98 |
+
"_;": 93,
|
| 99 |
+
"_<": 68,
|
| 100 |
+
"_<CO": 97,
|
| 101 |
+
"_<CODE>": 104,
|
| 102 |
+
"_<DE": 96,
|
| 103 |
+
"_<DEDENT>": 103,
|
| 104 |
+
"_<IN": 98,
|
| 105 |
+
"_<INDENT>": 105,
|
| 106 |
+
"_<JAV": 152,
|
| 107 |
+
"_<JAVASCR": 167,
|
| 108 |
+
"_<JAVASCRIPT>": 170,
|
| 109 |
+
"_<NL>": 74,
|
| 110 |
+
"_<P": 83,
|
| 111 |
+
"_<PROMPT>": 101,
|
| 112 |
+
"_<PYTHO": 153,
|
| 113 |
+
"_<PYTHON>": 168,
|
| 114 |
+
"_Cre": 143,
|
| 115 |
+
"_Create": 164,
|
| 116 |
+
"_Ja": 144,
|
| 117 |
+
"_JavaSc": 165,
|
| 118 |
+
"_JavaScript": 169,
|
| 119 |
+
"_Py": 145,
|
| 120 |
+
"_Python": 166,
|
| 121 |
+
"_a": 71,
|
| 122 |
+
"_add": 117,
|
| 123 |
+
"_b": 94,
|
| 124 |
+
"_f": 84,
|
| 125 |
+
"_for": 154,
|
| 126 |
+
"_funct": 119,
|
| 127 |
+
"_function": 121,
|
| 128 |
+
"_i": 78,
|
| 129 |
+
"_l": 85,
|
| 130 |
+
"_lo": 102,
|
| 131 |
+
"_log": 158,
|
| 132 |
+
"_loop": 159,
|
| 133 |
+
"_p": 140,
|
| 134 |
+
"_print": 163,
|
| 135 |
+
"_re": 108,
|
| 136 |
+
"_retu": 120,
|
| 137 |
+
"_return": 122,
|
| 138 |
+
"_t": 76,
|
| 139 |
+
"_th": 118,
|
| 140 |
+
"_that": 161,
|
| 141 |
+
"_to": 100,
|
| 142 |
+
"_{": 141,
|
| 143 |
+
"_}": 142,
|
| 144 |
+
"a": 44,
|
| 145 |
+
"at": 95,
|
| 146 |
+
"ate": 157,
|
| 147 |
+
"b": 45,
|
| 148 |
+
"c": 46,
|
| 149 |
+
"ct": 109,
|
| 150 |
+
"d": 47,
|
| 151 |
+
"dd": 110,
|
| 152 |
+
"e": 48,
|
| 153 |
+
"f": 49,
|
| 154 |
+
"g": 50,
|
| 155 |
+
"h": 51,
|
| 156 |
+
"hon": 146,
|
| 157 |
+
"i": 52,
|
| 158 |
+
"ion": 111,
|
| 159 |
+
"l": 53,
|
| 160 |
+
"m": 54,
|
| 161 |
+
"n": 55,
|
| 162 |
+
"nct": 112,
|
| 163 |
+
"nt": 147,
|
| 164 |
+
"o": 56,
|
| 165 |
+
"on": 82,
|
| 166 |
+
"op": 148,
|
| 167 |
+
"or": 149,
|
| 168 |
+
"p": 57,
|
| 169 |
+
"pt": 150,
|
| 170 |
+
"r": 58,
|
| 171 |
+
"re": 86,
|
| 172 |
+
"ri": 87,
|
| 173 |
+
"rint": 155,
|
| 174 |
+
"ript": 156,
|
| 175 |
+
"rn": 113,
|
| 176 |
+
"s": 59,
|
| 177 |
+
"t": 60,
|
| 178 |
+
"thon": 151,
|
| 179 |
+
"tu": 114,
|
| 180 |
+
"u": 61,
|
| 181 |
+
"unct": 115,
|
| 182 |
+
"v": 62,
|
| 183 |
+
"va": 116,
|
| 184 |
+
"vaSc": 160,
|
| 185 |
+
"w": 63,
|
| 186 |
+
"x": 64,
|
| 187 |
+
"y": 65,
|
| 188 |
+
"{": 66,
|
| 189 |
+
"}": 67
|
| 190 |
+
}
|
| 191 |
+
}
|
logs/data_fetch.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:250881bf6b7747176c7432a40e84fb3dc4eeca6f9a1a75378ee7e3ccdf662fbf
|
| 3 |
+
size 44778
|
merge.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from peft import PeftModel
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
|
| 4 |
+
base_model_path = "hf_release/MINDI-1.0-420M"
|
| 5 |
+
lora_path = "output/checkpoints/checkpoint-12000"
|
| 6 |
+
|
| 7 |
+
print("Loading base model...")
|
| 8 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 9 |
+
base_model_path,
|
| 10 |
+
trust_remote_code=True
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
print("Loading LoRA...")
|
| 14 |
+
model = PeftModel.from_pretrained(model, lora_path)
|
| 15 |
+
|
| 16 |
+
print("Merging...")
|
| 17 |
+
model = model.merge_and_unload()
|
| 18 |
+
|
| 19 |
+
print("Saving final model...")
|
| 20 |
+
model.save_pretrained("final_model", safe_serialization=False)
|
| 21 |
+
|
| 22 |
+
print("Saving tokenizer...")
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 24 |
+
base_model_path,
|
| 25 |
+
trust_remote_code=True
|
| 26 |
+
)
|
| 27 |
+
tokenizer.save_pretrained("final_model")
|
| 28 |
+
|
| 29 |
+
print("✅ DONE")
|
requirements.txt
CHANGED
|
@@ -3,3 +3,4 @@ datasets
|
|
| 3 |
peft
|
| 4 |
accelerate
|
| 5 |
torch
|
|
|
|
|
|
| 3 |
peft
|
| 4 |
accelerate
|
| 5 |
torch
|
| 6 |
+
tqdm
|
test.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 2 |
+
|
| 3 |
+
model_path = "final_model"
|
| 4 |
+
|
| 5 |
+
print("Loading tokenizer...")
|
| 6 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 7 |
+
model_path,
|
| 8 |
+
trust_remote_code=True
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
print("Loading model...")
|
| 12 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 13 |
+
model_path,
|
| 14 |
+
trust_remote_code=True
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# 🔥 FIXES
|
| 18 |
+
model = model.float()
|
| 19 |
+
model.config.num_hidden_layers = getattr(model.config, "n_layer", 12)
|
| 20 |
+
model.config.is_encoder_decoder = False
|
| 21 |
+
|
| 22 |
+
prompt = "Write a Python function for binary search"
|
| 23 |
+
|
| 24 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 25 |
+
|
| 26 |
+
print("Generating...")
|
| 27 |
+
output = model.generate(
|
| 28 |
+
**inputs,
|
| 29 |
+
max_new_tokens=200,
|
| 30 |
+
temperature=0.7,
|
| 31 |
+
do_sample=True
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
print("\n=== OUTPUT ===\n")
|
| 35 |
+
print(tokenizer.decode(output[0], skip_special_tokens=True))
|