| | """
|
| | Chat Formatter for TouchGrass.
|
| | Formats data into chat format compatible with Qwen3.5 fine-tuning.
|
| | """
|
| |
|
| | from typing import List, Dict, Any, Optional
|
| | import json
|
| | from pathlib import Path
|
| |
|
| |
|
class ChatFormatter:
    """
    Formats music QA data into chat format for instruction tuning.

    Handles:
    - System prompt injection
    - Context tags (instrument, skill level, emotion)
    - Tokenization-ready format
    - Multi-turn conversations
    """

    def __init__(
        self,
        tokenizer=None,
        max_seq_length: int = 4096,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize chat formatter.

        Args:
            tokenizer: Optional tokenizer for length validation. Its
                ``encode`` method may return either a plain list of token
                ids or a mapping with an ``"input_ids"`` key — both forms
                are supported (see ``_token_ids``).
            max_seq_length: Maximum sequence length in tokens.
            system_prompt: Optional custom system prompt; defaults to the
                built-in Touch Grass prompt.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        self.default_system_prompt = system_prompt or self._get_default_system_prompt()

    def _get_default_system_prompt(self) -> str:
        """Get default system prompt."""
        return """You are Touch Grass 🌿, a warm, encouraging, and knowledgeable music assistant.

You help people with:
- Learning instruments (guitar, bass, piano, keys, drums, vocals)
- Understanding music theory at any level
- Writing songs (lyrics, chord progressions, structure)
- Ear training and developing musicality
- DJ skills and music production
- Genre knowledge and music history

Your personality:
- Patient and encouraging — learning music is hard and takes time
- Adapt to the learner's level automatically — simpler for beginners, deeper for advanced
- When someone is frustrated, acknowledge it warmly before helping
- Use tabs, chord diagrams, and notation when helpful
- Make learning fun, not intimidating
- Celebrate small wins

When generating tabs use this format:
[TAB]
e|---------|
B|---------|
G|---------|
D|---------|
A|---------|
E|---------|
[/TAB]

When showing chord progressions use: [PROGRESSION]I - IV - V - I[/PROGRESSION]"""

    def format_qa_pair(
        self,
        question: str,
        answer: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Format a single QA pair into chat format.

        Args:
            question: User question
            answer: Assistant answer
            context: Optional context tags (e.g., "[GUITAR][BEGINNER]"),
                prepended to the question text.
            system_prompt: Optional system prompt override

        Returns:
            Dict with a "messages" list: system, user, assistant.
        """
        system = system_prompt or self.default_system_prompt

        user_message = question
        if context:
            user_message = f"{context} {question}".strip()

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": answer},
        ]

        # Length validation/truncation only when a tokenizer was supplied.
        if self.tokenizer:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Sample exceeds max length ({total_length} > {self.max_seq_length})")
                messages = self._truncate_answers(messages)

        return {"messages": messages}

    def format_multi_turn(
        self,
        conversations: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Format multi-turn conversation.

        Args:
            conversations: List of {"role": "...", "content": "..."} dicts
            system_prompt: Optional system prompt

        Returns:
            Formatted chat dictionary
        """
        system = system_prompt or self.default_system_prompt

        # Guard: an empty conversation would otherwise raise IndexError on
        # conversations[0]; return the system message alone instead.
        if not conversations:
            return {"messages": [{"role": "system", "content": system}]}

        if conversations[0]["role"] != "system":
            messages = [{"role": "system", "content": system}] + conversations
        else:
            messages = conversations

        if self.tokenizer:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Multi-turn sample exceeds max length ({total_length} > {self.max_seq_length})")
                messages = self._truncate_multi_turn(messages)

        return {"messages": messages}

    def _token_ids(self, text: str) -> List[int]:
        """
        Tokenize text, tolerating both common ``encode`` return shapes.

        HF ``PreTrainedTokenizer.encode`` returns a plain list of ids,
        while calling a tokenizer object (or some wrappers) returns a
        mapping with an "input_ids" key. Accept either.
        """
        encoded = self.tokenizer.encode(text)
        try:
            return list(encoded["input_ids"])
        except (TypeError, KeyError):
            # Not a mapping — assume a sequence of token ids.
            return list(encoded)

    def _estimate_length(self, messages: List[Dict[str, str]]) -> int:
        """Estimate total token length of messages (0 without a tokenizer)."""
        if not self.tokenizer:
            return 0

        return sum(len(self._token_ids(msg["content"])) for msg in messages)

    def _truncate_answers(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate the assistant answer (messages[2]) to fit max length.

        Leaves the list untouched when there is no tokenizer or no
        assistant turn. Mutates ``messages`` in place and returns it.
        """
        if not self.tokenizer or len(messages) < 3:
            return messages

        system_len = self._estimate_length([messages[0]])
        user_len = self._estimate_length([messages[1]])
        # Reserve a 10-token margin for chat-template special tokens;
        # clamp at 0 so oversized prompts can't produce negative slices.
        available = max(self.max_seq_length - system_len - user_len - 10, 0)

        answer_msg = messages[2].copy()
        answer_ids = self._token_ids(answer_msg["content"])
        if len(answer_ids) > available:
            # Keep room for the literal "..." suffix marking the cut.
            keep = max(available - 3, 0)
            truncated = self.tokenizer.decode(answer_ids[:keep])
            answer_msg["content"] = truncated + "..."
            messages[2] = answer_msg

        return messages

    def _truncate_multi_turn(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate multi-turn conversation from the end.

        Keeps the system message plus as many leading turns as fit within
        ``max_seq_length`` (minus a 10-token margin), dropping the rest.
        """
        if not self.tokenizer:
            return messages

        system_msg = messages[0]
        other_msgs = messages[1:]

        current_length = self._estimate_length([system_msg])
        kept_msgs = []

        for msg in other_msgs:
            msg_len = self._estimate_length([msg])
            if current_length + msg_len <= self.max_seq_length - 10:
                kept_msgs.append(msg)
                current_length += msg_len
            else:
                break

        return [system_msg] + kept_msgs

    def save_as_jsonl(
        self,
        samples: List[Dict[str, Any]],
        output_path: str,
    ):
        """
        Save formatted samples as JSONL.

        Args:
            samples: List of formatted samples
            output_path: Output file path (parent dirs created as needed)
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            for sample in samples:
                # ensure_ascii=False keeps emoji / non-ASCII readable on disk.
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")

        print(f"Saved {len(samples)} samples to {output_path}")

    def load_from_jsonl(
        self,
        input_path: str,
    ) -> List[Dict[str, Any]]:
        """
        Load formatted samples from JSONL.

        Args:
            input_path: Input file path

        Returns:
            List of samples
        """
        samples = []
        with open(input_path, "r", encoding="utf-8") as f:
            for line in f:
                samples.append(json.loads(line))

        print(f"Loaded {len(samples)} samples from {input_path}")
        return samples

    def validate_sample(
        self,
        sample: Dict[str, Any],
    ) -> bool:
        """
        Validate a formatted sample.

        Checks for a "messages" list of at least two entries that starts
        with a system message and then alternates user/assistant.

        Args:
            sample: Sample to validate

        Returns:
            True if valid
        """
        if "messages" not in sample:
            print("Error: Missing 'messages' field")
            return False

        messages = sample["messages"]
        if len(messages) < 2:
            print("Error: At least 2 messages required (system + user)")
            return False

        if messages[0]["role"] != "system":
            print("Error: First message must be system")
            return False

        # After the system message, roles must alternate user/assistant;
        # a trailing user turn without an assistant reply is allowed.
        for i in range(1, len(messages), 2):
            if messages[i]["role"] != "user":
                print(f"Error: Expected user at position {i}, got {messages[i]['role']}")
                return False
            if i + 1 < len(messages) and messages[i + 1]["role"] != "assistant":
                print(f"Error: Expected assistant at position {i+1}, got {messages[i+1]['role']}")
                return False

        return True

    def create_pretraining_dataset(
        self,
        qa_samples: List[Dict[str, Any]],
        output_dir: str,
        train_split: float = 0.9,
    ) -> Dict[str, str]:
        """
        Create train/val splits for fine-tuning.

        Args:
            qa_samples: List of QA samples (the caller's list is NOT mutated)
            output_dir: Output directory
            train_split: Train split ratio (0-1)

        Returns:
            Dictionary with train/val file paths
        """
        import random
        # random.sample produces a shuffled copy, so the caller's list keeps
        # its order (random.shuffle would reorder it in place).
        shuffled = random.sample(qa_samples, len(qa_samples))

        split_idx = int(len(shuffled) * train_split)
        train_samples = shuffled[:split_idx]
        val_samples = shuffled[split_idx:]

        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        train_path = out_dir / "train.jsonl"
        val_path = out_dir / "val.jsonl"

        self.save_as_jsonl(train_samples, str(train_path))
        self.save_as_jsonl(val_samples, str(val_path))

        print(f"Created splits: train={len(train_samples)}, val={len(val_samples)}")

        return {
            "train": str(train_path),
            "val": str(val_path),
        }
|
| |
|
| |
|
def test_chat_formatter():
    """Smoke-test the ChatFormatter end to end."""
    fmt = ChatFormatter()

    print("Testing ChatFormatter...\n")

    # Single QA pair with context tags prepended to the question.
    single = fmt.format_qa_pair(
        question="How do I play a G chord?",
        answer="[TAB]...[/TAB] Here's how...",
        context="[GUITAR][BEGINNER]",
    )

    print("Formatted QA pair:")
    for message in single["messages"]:
        print(f" {message['role']}: {message['content'][:80]}...")

    is_valid = fmt.validate_sample(single)
    print(f"\nSample valid: {is_valid}")

    # Multi-turn conversation; a system message is injected automatically.
    turns = [
        {"role": "user", "content": "What is a chord?"},
        {"role": "assistant", "content": "A chord is..."},
        {"role": "user", "content": "Can you give an example?"},
        {"role": "assistant", "content": "C major is C-E-G"},
    ]
    multi = fmt.format_multi_turn(turns)

    print("\nMulti-turn format:")
    for message in multi["messages"]:
        print(f" {message['role']}: {message['content'][:60]}...")

    print("\nChatFormatter test complete!")
|
| |
|
| |
|
# Run the smoke test when executed directly as a script.
if __name__ == "__main__":
    test_chat_formatter()