# lumia-tiny / prepare_tiny_data.py
"""
Prepare data for Lumia-Tiny training.
Usage:
python3 scripts/prepare_tiny_data.py # use existing data/data.jsonl
python3 scripts/prepare_tiny_data.py --generate 50 # generate N synthetic examples
python3 scripts/prepare_tiny_data.py --hf-sample 100 # sample from HF dataset
python3 scripts/prepare_tiny_data.py --to-jsonl # convert to simple format
"""
import argparse
import json
import os
import random
import sys
SYSTEM_PROMPTS = [
"You are a helpful AI assistant who solves problems step by step.",
"You are a precise programming assistant who writes clean, correct code.",
"You are a math tutor who explains concepts clearly.",
"You are a reasoning assistant who thinks through problems carefully.",
]
INSTRUCTIONS = [
"What is 2+2?",
"Explain how a binary search works.",
"Write a Python function to reverse a linked list.",
"What is the capital of France?",
"Explain the concept of recursion.",
"Write a function to check if a string is a palindrome.",
"What is the difference between TCP and UDP?",
"Explain how gradient descent works.",
"Write a quicksort implementation.",
"What is the time complexity of binary search?",
"Explain the concept of overfitting in machine learning.",
"Write a function to find the nth Fibonacci number.",
"What is the difference between a list and a tuple in Python?",
"Explain the CAP theorem.",
"Write a function to merge two sorted arrays.",
"What is the Pythagorean theorem?",
"Explain how HTTP works.",
"Write a function to calculate the factorial of a number.",
"What is the difference between SQL and NoSQL?",
"Explain the concept of polymorphism in OOP.",
"What is 15 * 7?",
"Write a function to find the maximum subarray sum.",
"Explain how DNS resolution works.",
"What is the difference between stack and heap memory?",
"Write a function to detect cycles in a linked list.",
"What is the derivative of x^2?",
"Explain the concept of normalization in databases.",
"Write a function to perform binary search in a sorted array.",
"What is the difference between symmetric and asymmetric encryption?",
"Explain how a hash table works.",
"What is 144 / 12?",
"Write a function to find all prime numbers up to n.",
"Explain the difference between REST and GraphQL.",
"What is the chain rule in calculus?",
"Write a function to serialize a binary tree.",
"What is the difference between HTTP and HTTPS?",
"Explain the concept of attention in transformers.",
"Write a function to compute the edit distance between two strings.",
"What is the integral of 2x?",
"Explain how a blockchain works.",
]
RESPONSES = [
"4",
"Binary search works by repeatedly dividing a sorted array in half, comparing the target value to the middle element, and eliminating the half that cannot contain the target.",
"```python\ndef reverse_linked_list(head):\n prev = None\n curr = head\n while curr:\n next_node = curr.next\n curr.next = prev\n prev = curr\n curr = next_node\n return prev\n```",
"Paris",
"Recursion is a programming technique where a function calls itself to solve a problem by breaking it down into smaller subproblems. Each recursive call works on a smaller instance until reaching a base case.",
"```python\ndef is_palindrome(s):\n s = ''.join(c.lower() for c in s if c.isalnum())\n return s == s[::-1]\n```",
"TCP is connection-oriented and guarantees delivery with error checking. UDP is connectionless and faster but does not guarantee delivery.",
"Gradient descent is an optimization algorithm that iteratively adjusts parameters to minimize a loss function by moving in the direction of the negative gradient.",
"```python\ndef quicksort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr)//2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quicksort(left) + middle + quicksort(right)\n```",
"O(log n)",
"Overfitting occurs when a model learns training data too well, including noise, and performs poorly on new data. It can be addressed by regularization, dropout, or early stopping.",
"```python\ndef fibonacci(n):\n if n <= 1:\n return n\n a, b = 0, 1\n for _ in range(2, n+1):\n a, b = b, a + b\n return b\n```",
"Lists are mutable, tuples are immutable. Lists use square brackets [], tuples use parentheses ().",
"The CAP theorem states that a distributed system cannot simultaneously guarantee Consistency, Availability, and Partition tolerance.",
"```python\ndef merge_sorted(a, b):\n result = []\n i = j = 0\n while i < len(a) and j < len(b):\n if a[i] < b[j]:\n result.append(a[i]); i += 1\n else:\n result.append(b[j]); j += 1\n result.extend(a[i:]); result.extend(b[j:])\n return result\n```",
"a² + b² = c², where a and b are legs and c is the hypotenuse.",
"HTTP is a request-response protocol where a client sends a request to a server, which replies with a status code and body.",
"```python\ndef factorial(n):\n if n <= 1:\n return 1\n return n * factorial(n-1)\n```",
"SQL databases are relational with structured schemas and ACID properties. NoSQL databases are non-relational, schema-flexible, and scale horizontally.",
"Polymorphism allows objects of different types to respond to the same interface. In Python, duck typing lets any object with the required methods be used.",
"105",
"```python\ndef max_subarray_sum(arr):\n max_ending = max_sofar = arr[0]\n for x in arr[1:]:\n max_ending = max(x, max_ending + x)\n max_sofar = max(max_sofar, max_ending)\n return max_sofar\n```",
"DNS converts domain names to IP addresses. The resolver queries root -> TLD -> authoritative servers to find the IP.",
"Stack is LIFO (last-in-first-out) for local variables and function calls. Heap is for dynamically allocated memory with longer lifespan.",
"```python\ndef has_cycle(head):\n slow = fast = head\n while fast and fast.next:\n slow = slow.next\n fast = fast.next.next\n if slow == fast:\n return True\n return False\n```",
"2x",
"Normalization organizes data to reduce redundancy. Forms: 1NF (atomic columns), 2NF (no partial dependency), 3NF (no transitive dependency).",
"```python\ndef binary_search(arr, target):\n lo, hi = 0, len(arr) - 1\n while lo <= hi:\n mid = (lo + hi) // 2\n if arr[mid] == target:\n return mid\n elif arr[mid] < target:\n lo = mid + 1\n else:\n hi = mid - 1\n return -1\n```",
"Symmetric encryption uses the same key for both encryption and decryption. Asymmetric uses a public key to encrypt and a private key to decrypt.",
"A hash table uses a hash function to map keys to array indices, providing O(1) average lookup. Collisions are handled by chaining or open addressing.",
"12",
"```python\ndef sieve(n):\n primes = [True] * (n+1)\n primes[0] = primes[1] = False\n for i in range(2, int(n**0.5)+1):\n if primes[i]:\n for j in range(i*i, n+1, i):\n primes[j] = False\n return [i for i, p in enumerate(primes) if p]\n```",
"REST uses standard HTTP methods with resource-based URLs. GraphQL uses a single endpoint with a query language for flexible data fetching.",
"d/dx f(g(x)) = f'(g(x)) * g'(x)",
"```python\ndef serialize_tree(root):\n def encode(node):\n if not node:\n return 'null'\n return f\"{node.val},{encode(node.left)},{encode(node.right)}\"\n return encode(root)\n```",
"HTTPS adds TLS encryption on top of HTTP, providing confidentiality and integrity.",
"Attention computes weighted combinations of values based on query-key similarity. The Transformer uses multi-head attention to capture different relationship types.",
"```python\ndef edit_distance(a, b):\n m, n = len(a), len(b)\n dp = [[0]*(n+1) for _ in range(m+1)]\n for i in range(m+1): dp[i][0] = i\n for j in range(n+1): dp[0][j] = j\n for i in range(1, m+1):\n for j in range(1, n+1):\n if a[i-1] == b[j-1]:\n dp[i][j] = dp[i-1][j-1]\n else:\n dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])\n return dp[m][n]\n```",
"x² + C",
"A blockchain is a distributed ledger where data is stored in blocks linked by cryptographic hashes. Each block contains a hash of the previous block, forming a chain.",
]
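
def generate_synthetic(n=50):
    """Build n records by cycling through the prompt, instruction, and response lists.

    INSTRUCTIONS and RESPONSES are paired by index, so the two lists must stay
    the same length and in the same order.
    """
    # No random draws happen below; the seed only makes later random calls reproducible.
    random.seed(42)
    samples = []
    for i in range(n):
        sys_idx = i % len(SYSTEM_PROMPTS)
        inst_idx = i % len(INSTRUCTIONS)
        resp_idx = i % len(RESPONSES)
        samples.append({
            "system": SYSTEM_PROMPTS[sys_idx],
            "instruction": INSTRUCTIONS[inst_idx],
            "input": "",
            "output": RESPONSES[resp_idx],
        })
    return samples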
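
# A minimal sketch, not part of the original script: generate_synthetic pairs
# INSTRUCTIONS[i] with RESPONSES[i] by index, so the two lists must have the
# same length. The helper name `check_pairing` and its return format are
# assumptions made for illustration.
def check_pairing(instructions, responses):
    """Return a list of problems found in two index-paired lists."""
    problems = []
    if len(instructions) != len(responses):
        problems.append(
            f"length mismatch: {len(instructions)} instructions, "
            f"{len(responses)} responses"
        )
    # Empty strings on either side would yield unusable training pairs.
    for i, (inst, resp) in enumerate(zip(instructions, responses)):
        if not inst.strip() or not resp.strip():
            problems.append(f"empty text at index {i}")
    return problems
# Usage sketch: check_pairing(INSTRUCTIONS, RESPONSES) returns an empty list
# when the two lists line up and contain no blank entries.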
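
def load_and_split(data_path, test_ratio=0.1):
    """Load a JSONL file, shuffle it with a fixed seed, and return (train, test)."""
    with open(data_path, encoding="utf-8") as f:
        # Skip blank lines so a trailing newline does not break json.loads.
        data = [json.loads(line) for line in f if line.strip()]
    random.seed(42)
    random.shuffle(data)
    split = int(len(data) * (1 - test_ratio))
    return data[:split], data[split:]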
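
# A small sketch of the split arithmetic used in load_and_split, handy for
# checking how many examples land in each part before training. `split_sizes`
# is a hypothetical helper, not something the original script defines.
def split_sizes(n, test_ratio=0.1):
    """Return (train_count, test_count) for n examples."""
    train = int(n * (1 - test_ratio))  # same truncation as load_and_split
    return train, n - train
# With very small datasets the truncation matters: split_sizes(1) puts the
# single example in the test part and leaves the training part empty.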

def convert_to_messages(data):
    """Convert instruction-style records into chat-style "messages" records."""
    out = []
    for item in data:
        user_msg = item["instruction"]
        # Append the optional input on its own line when it is non-empty.
        if item.get("input", ""):
            user_msg += "\n" + item["input"]
        out.append({
            "messages": [
                {"role": "system", "content": item.get("system", "")},
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": item["output"]},
            ]
        })
    return out
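
# convert_to_messages indexes item["instruction"] and item["output"] directly,
# so a record missing either key raises KeyError. A minimal pre-check might
# look like this; `missing_keys` and REQUIRED_KEYS are illustrative names,
# not part of the original script.
REQUIRED_KEYS = ("instruction", "output")

def missing_keys(item):
    """Return the required keys that are absent from a record."""
    return [k for k in REQUIRED_KEYS if k not in item]
# Usage sketch: skip or report records where missing_keys(item) is non-empty
# before passing the rest to convert_to_messages.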

def main():
    parser = argparse.ArgumentParser(description="Prepare data for Lumia-Tiny")
    parser.add_argument("--generate", type=int, default=None,
                        help="Generate N synthetic examples")
    parser.add_argument("--to-jsonl", action="store_true",
                        help="Convert data/data.jsonl to messages format")
    parser.add_argument("--output", default="data/tiny_data.jsonl",
                        help="Output path")
    args = parser.parse_args()

    output_path = args.output
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if args.generate:
        samples = generate_synthetic(args.generate)
        with open(output_path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")
        print(f" Generated {len(samples)} synthetic examples → {output_path}")
        return

    if args.to_jsonl:
        data_path = "data/data.jsonl"
        if not os.path.exists(data_path):
            print(f" Error: {data_path} not found")
            sys.exit(1)
        with open(data_path) as f:
            # Skip blank lines so they do not break json.loads.
            data = [json.loads(line) for line in f if line.strip()]
        messages = convert_to_messages(data)
        base = os.path.splitext(output_path)[0]
        with open(f"{base}_messages.jsonl", "w") as f:
            for m in messages:
                f.write(json.dumps(m) + "\n")
        print(f" Converted {len(messages)} examples → {base}_messages.jsonl")
        return

    # Default: just report stats
    data_path = "data/data.jsonl"
    if os.path.exists(data_path):
        with open(data_path) as f:
            # Ignore blank lines so they are not counted as samples.
            lines = [line for line in f if line.strip()]
        print(f" Data file: {data_path}")
        print(f" Samples: {len(lines)}")
        print(f" Size: {os.path.getsize(data_path):,} bytes")
        if lines:  # an empty file has no first record to inspect
            sample = json.loads(lines[0])
            print(f" Fields: {list(sample.keys())}")
        print(" Expected keys: system, instruction, input, output")
        print()
        print(" To generate synthetic data:")
        print(" python3 scripts/prepare_tiny_data.py --generate 200")
        print()
        print(" To convert to messages format:")
        print(" python3 scripts/prepare_tiny_data.py --to-jsonl")
    else:
        print(" No data found. Generate with:")
        print(" python3 scripts/prepare_tiny_data.py --generate 200")


if __name__ == "__main__":
    main()