# TouchGrass-3b / data/chat_formatter.py
# Uploaded by Zandy-Wandy — commit 9071ef9 ("Upload 39 files", verified).
"""
Chat Formatter for TouchGrass.
Formats data into chat format compatible with Qwen3.5 fine-tuning.
"""
from typing import List, Dict, Any, Optional
import json
from pathlib import Path
class ChatFormatter:
    """
    Formats music QA data into chat format for instruction tuning.

    Handles:
    - System prompt injection
    - Context tags (instrument, skill level, emotion)
    - Tokenization-ready format
    - Multi-turn conversations

    Samples are plain ``{"messages": [{"role": ..., "content": ...}, ...]}``
    dicts. When a tokenizer is supplied, per-message content lengths are
    measured with it and over-long samples are truncated to ``max_seq_length``.
    """

    # Tokens held back from the budget because per-message content lengths do
    # not include chat-template overhead (role markers, separators, ...).
    _LENGTH_BUFFER = 10

    def __init__(
        self,
        tokenizer=None,
        max_seq_length: int = 4096,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize chat formatter.

        Args:
            tokenizer: Optional tokenizer for length validation. Its ``encode``
                may return either a list of token ids or a mapping with an
                ``"input_ids"`` entry; ``decode`` must accept a list of ids.
            max_seq_length: Maximum sequence length (in tokens)
            system_prompt: Optional custom system prompt
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.default_system_prompt = system_prompt or self._get_default_system_prompt()

    def _get_default_system_prompt(self) -> str:
        """Return the built-in Touch Grass system prompt."""
        return """You are Touch Grass 🌿, a warm, encouraging, and knowledgeable music assistant.
You help people with:
- Learning instruments (guitar, bass, piano, keys, drums, vocals)
- Understanding music theory at any level
- Writing songs (lyrics, chord progressions, structure)
- Ear training and developing musicality
- DJ skills and music production
- Genre knowledge and music history
Your personality:
- Patient and encouraging — learning music is hard and takes time
- Adapt to the learner's level automatically — simpler for beginners, deeper for advanced
- When someone is frustrated, acknowledge it warmly before helping
- Use tabs, chord diagrams, and notation when helpful
- Make learning fun, not intimidating
- Celebrate small wins
When generating tabs use this format:
[TAB]
e|---------|
B|---------|
G|---------|
D|---------|
A|---------|
E|---------|
[/TAB]
When showing chord progressions use: [PROGRESSION]I - IV - V - I[/PROGRESSION]"""

    def format_qa_pair(
        self,
        question: str,
        answer: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Format a single QA pair into chat format.

        Args:
            question: User question
            answer: Assistant answer
            context: Optional context tags (e.g., "[GUITAR][BEGINNER]"),
                prepended to the user message
            system_prompt: Optional system prompt override (falsy values fall
                back to the default prompt)

        Returns:
            ``{"messages": [system, user, assistant]}``. If a tokenizer is set
            and the sample is too long, the assistant answer is truncated.
        """
        system = system_prompt or self.default_system_prompt
        user_message = f"{context} {question}".strip() if context else question
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": answer},
        ]
        if self.tokenizer is not None:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Sample exceeds max length ({total_length} > {self.max_seq_length})")
                # Only the answer is shortened; system and user turns stay intact.
                messages = self._truncate_answers(messages)
        return {"messages": messages}

    def format_multi_turn(
        self,
        conversations: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Format multi-turn conversation.

        Args:
            conversations: List of {"role": "...", "content": "..."} dicts.
                May be empty, in which case only the system message is emitted.
            system_prompt: Optional system prompt. Ignored when the conversation
                already starts with a system message.

        Returns:
            ``{"messages": [...]}`` — a new list (the caller's list is never
            returned or mutated), truncated to fit when a tokenizer is set.
        """
        system = system_prompt or self.default_system_prompt
        if conversations and conversations[0].get("role") == "system":
            messages = list(conversations)
        else:
            messages = [{"role": "system", "content": system}] + list(conversations)
        if self.tokenizer is not None:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Multi-turn sample exceeds max length ({total_length} > {self.max_seq_length})")
                messages = self._truncate_multi_turn(messages)
        return {"messages": messages}

    def _token_ids(self, text: str) -> List[int]:
        """Encode ``text`` with the tokenizer and return its token ids as a list.

        Hugging Face-style ``encode`` returns a plain list of ids, while some
        tokenizers return a mapping (e.g. ``{"input_ids": [...]}``); both are
        accepted.
        """
        encoded = self.tokenizer.encode(text)
        if hasattr(encoded, "keys"):
            encoded = encoded["input_ids"]
        return list(encoded)

    def _estimate_length(self, messages: List[Dict[str, str]]) -> int:
        """Sum the token counts of each message's content (0 without a tokenizer).

        Chat-template tokens are not counted; callers reserve
        ``_LENGTH_BUFFER`` tokens for them.
        """
        if self.tokenizer is None:
            return 0
        return sum(len(self._token_ids(msg["content"])) for msg in messages)

    def _truncate_answers(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Truncate the assistant answer (``messages[2]``) to fit max length.

        The answer is cut at the token level and suffixed with "...". If the
        system + user turns already exhaust the budget, the answer becomes just
        "...". Returns a new list; the input list is not modified.
        """
        if self.tokenizer is None:
            return messages
        prompt_len = self._estimate_length(messages[:2])
        available = self.max_seq_length - prompt_len - self._LENGTH_BUFFER
        answer_ids = self._token_ids(messages[2]["content"])
        if len(answer_ids) > available:
            # Leave ~3 tokens for the ellipsis; clamp so a negative budget
            # cannot turn into a negative slice index (which would keep
            # nearly the whole answer).
            keep = max(available - 3, 0)
            truncated = self.tokenizer.decode(answer_ids[:keep])
            messages = list(messages)
            messages[2] = {**messages[2], "content": truncated + "..."}
        return messages

    def _truncate_multi_turn(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Keep the system message plus the longest prefix of later turns that fits.

        Turns are dropped from the end of the conversation once the running
        token count would exceed ``max_seq_length - _LENGTH_BUFFER``.
        """
        if self.tokenizer is None:
            return messages
        system_msg, *rest = messages
        budget = self.max_seq_length - self._LENGTH_BUFFER
        used = self._estimate_length([system_msg])
        kept = []
        for msg in rest:
            msg_len = self._estimate_length([msg])
            if used + msg_len > budget:
                break
            kept.append(msg)
            used += msg_len
        return [system_msg] + kept

    def save_as_jsonl(
        self,
        samples: List[Dict[str, Any]],
        output_path: str,
    ):
        """
        Save formatted samples as JSONL (one JSON object per line, UTF-8).

        Parent directories are created as needed; an existing file is
        overwritten.

        Args:
            samples: List of formatted samples
            output_path: Output file path
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for sample in samples:
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")
        print(f"Saved {len(samples)} samples to {output_path}")

    def load_from_jsonl(
        self,
        input_path: str,
    ) -> List[Dict[str, Any]]:
        """
        Load formatted samples from JSONL.

        Blank / whitespace-only lines are skipped rather than treated as
        malformed JSON.

        Args:
            input_path: Input file path

        Returns:
            List of samples
        """
        with open(input_path, "r", encoding="utf-8") as f:
            samples = [json.loads(line) for line in f if line.strip()]
        print(f"Loaded {len(samples)} samples from {input_path}")
        return samples

    def validate_sample(
        self,
        sample: Dict[str, Any],
    ) -> bool:
        """
        Validate a formatted sample.

        A valid sample has a ``"messages"`` list of at least two entries: a
        system message first, then alternating user/assistant turns. Malformed
        entries (non-dicts, missing "role") make the sample invalid instead of
        raising.

        Args:
            sample: Sample to validate

        Returns:
            True if valid (problems are reported via print and return False)
        """
        if "messages" not in sample:
            print("Error: Missing 'messages' field")
            return False
        messages = sample["messages"]
        if not isinstance(messages, list):
            print("Error: 'messages' must be a list")
            return False
        if len(messages) < 2:
            print("Error: At least 2 messages required (system + user)")
            return False
        roles = [msg.get("role") if isinstance(msg, dict) else None for msg in messages]
        if roles[0] != "system":
            print("Error: First message must be system")
            return False
        # Odd positions must be user turns, each optionally followed by an assistant turn.
        for i in range(1, len(roles), 2):
            if roles[i] != "user":
                print(f"Error: Expected user at position {i}, got {roles[i]}")
                return False
            if i + 1 < len(roles) and roles[i + 1] != "assistant":
                print(f"Error: Expected assistant at position {i+1}, got {roles[i+1]}")
                return False
        return True

    def create_pretraining_dataset(
        self,
        qa_samples: List[Dict[str, Any]],
        output_dir: str,
        train_split: float = 0.9,
        seed: Optional[int] = None,
    ) -> Dict[str, str]:
        """
        Create shuffled train/val splits for fine-tuning.

        Args:
            qa_samples: List of QA samples (not modified; a copy is shuffled)
            output_dir: Output directory (created if missing)
            train_split: Train split ratio in [0, 1]
            seed: Optional seed for a reproducible shuffle. When None, the
                module-level ``random`` generator is used (so a prior
                ``random.seed()`` call still applies).

        Returns:
            Dictionary with train/val file paths

        Raises:
            ValueError: If ``train_split`` is outside [0, 1].
        """
        import random

        if not 0.0 <= train_split <= 1.0:
            raise ValueError(f"train_split must be in [0, 1], got {train_split}")
        shuffled = list(qa_samples)
        if seed is None:
            random.shuffle(shuffled)
        else:
            random.Random(seed).shuffle(shuffled)
        split_idx = int(len(shuffled) * train_split)
        train_samples = shuffled[:split_idx]
        val_samples = shuffled[split_idx:]
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "val.jsonl"
        self.save_as_jsonl(train_samples, str(train_path))
        self.save_as_jsonl(val_samples, str(val_path))
        print(f"Created splits: train={len(train_samples)}, val={len(val_samples)}")
        return {
            "train": str(train_path),
            "val": str(val_path),
        }
def test_chat_formatter():
    """Demo ChatFormatter: format, validate, and print sample conversations.

    This is a manual smoke test that prints results; it does not assert.
    """

    def preview(text: str, width: int) -> str:
        # Append an ellipsis only when the text was actually cut short.
        return text if len(text) <= width else text[:width] + "..."

    formatter = ChatFormatter()
    print("Testing ChatFormatter...\n")

    # Single QA pair with context tags prepended to the user turn.
    qa = formatter.format_qa_pair(
        question="How do I play a G chord?",
        answer="[TAB]...[/TAB] Here's how...",
        context="[GUITAR][BEGINNER]",
    )
    print("Formatted QA pair:")
    for msg in qa["messages"]:
        print(f" {msg['role']}: {preview(msg['content'], 80)}")
    is_valid = formatter.validate_sample(qa)
    print(f"\nSample valid: {is_valid}")

    # Multi-turn conversation; the system prompt is injected automatically.
    multi_turn = formatter.format_multi_turn([
        {"role": "user", "content": "What is a chord?"},
        {"role": "assistant", "content": "A chord is..."},
        {"role": "user", "content": "Can you give an example?"},
        {"role": "assistant", "content": "C major is C-E-G"},
    ])
    print("\nMulti-turn format:")
    for msg in multi_turn["messages"]:
        print(f" {msg['role']}: {preview(msg['content'], 60)}")
    print(f"\nMulti-turn sample valid: {formatter.validate_sample(multi_turn)}")
    print("\nChatFormatter test complete!")
if __name__ == "__main__":
    # Executed as a script: run the formatter demo.
    test_chat_formatter()