""" Chat Formatter for TouchGrass. Formats data into chat format compatible with Qwen3.5 fine-tuning. """ from typing import List, Dict, Any, Optional import json from pathlib import Path class ChatFormatter: """ Formats music QA data into chat format for instruction tuning. Handles: - System prompt injection - Context tags (instrument, skill level, emotion) - Tokenization-ready format - Multi-turn conversations """ def __init__( self, tokenizer=None, max_seq_length: int = 4096, system_prompt: Optional[str] = None, ): """ Initialize chat formatter. Args: tokenizer: Optional tokenizer for length validation max_seq_length: Maximum sequence length system_prompt: Optional custom system prompt """ self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.default_system_prompt = system_prompt or self._get_default_system_prompt() def _get_default_system_prompt(self) -> str: """Get default system prompt.""" return """You are Touch Grass 🌿, a warm, encouraging, and knowledgeable music assistant. You help people with: - Learning instruments (guitar, bass, piano, keys, drums, vocals) - Understanding music theory at any level - Writing songs (lyrics, chord progressions, structure) - Ear training and developing musicality - DJ skills and music production - Genre knowledge and music history Your personality: - Patient and encouraging — learning music is hard and takes time - Adapt to the learner's level automatically — simpler for beginners, deeper for advanced - When someone is frustrated, acknowledge it warmly before helping - Use tabs, chord diagrams, and notation when helpful - Make learning fun, not intimidating - Celebrate small wins When generating tabs use this format: [TAB] e|---------| B|---------| G|---------| D|---------| A|---------| E|---------| [/TAB] When showing chord progressions use: [PROGRESSION]I - IV - V - I[/PROGRESSION]""" def format_qa_pair( self, question: str, answer: str, context: Optional[str] = None, system_prompt: Optional[str] = None, ) -> Dict[str, Any]: """ Format a single QA pair into chat format. Args: question: User question answer: Assistant answer context: Optional context tags (e.g., "[GUITAR][BEGINNER]") system_prompt: Optional system prompt override Returns: Formatted chat dictionary """ system = system_prompt or self.default_system_prompt # Build user message with context user_message = question if context: user_message = f"{context} {question}".strip() messages = [ {"role": "system", "content": system}, {"role": "user", "content": user_message}, {"role": "assistant", "content": answer}, ] # Validate length if tokenizer provided if self.tokenizer: total_length = self._estimate_length(messages) if total_length > self.max_seq_length: print(f"Warning: Sample exceeds max length ({total_length} > {self.max_seq_length})") # Truncate answer if needed messages = self._truncate_answers(messages) return {"messages": messages} def format_multi_turn( self, conversations: List[Dict[str, str]], system_prompt: Optional[str] = None, ) -> Dict[str, Any]: """ Format multi-turn conversation. Args: conversations: List of {"role": "...", "content": "..."} dicts system_prompt: Optional system prompt Returns: Formatted chat dictionary """ system = system_prompt or self.default_system_prompt # Ensure system is first if conversations[0]["role"] != "system": messages = [{"role": "system", "content": system}] + conversations else: messages = conversations # Validate length if self.tokenizer: total_length = self._estimate_length(messages) if total_length > self.max_seq_length: print(f"Warning: Multi-turn sample exceeds max length ({total_length} > {self.max_seq_length})") messages = self._truncate_multi_turn(messages) return {"messages": messages} def _estimate_length(self, messages: List[Dict[str, str]]) -> int: """Estimate token length of messages.""" if not self.tokenizer: return 0 total = 0 for msg in messages: tokens = self.tokenizer.encode(msg["content"]) total += len(tokens["input_ids"]) return total def _truncate_answers(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: """Truncate answer to fit max length.""" if not self.tokenizer: return messages system_len = self._estimate_length([messages[0]]) user_len = self._estimate_length([messages[1]]) available = self.max_seq_length - system_len - user_len - 10 # buffer # Truncate answer answer_msg = messages[2].copy() answer_tokens = self.tokenizer.encode(answer_msg["content"]) if len(answer_tokens["input_ids"]) > available: # Truncate and add ellipsis truncated = self.tokenizer.decode(answer_tokens["input_ids"][:available-3]) answer_msg["content"] = truncated + "..." messages[2] = answer_msg return messages def _truncate_multi_turn(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: """Truncate multi-turn conversation from the end.""" if not self.tokenizer: return messages # Keep system and first few messages, truncate later ones system_msg = messages[0] other_msgs = messages[1:] current_length = self._estimate_length([system_msg]) kept_msgs = [] for msg in other_msgs: msg_len = self._estimate_length([msg]) if current_length + msg_len <= self.max_seq_length - 10: kept_msgs.append(msg) current_length += msg_len else: break return [system_msg] + kept_msgs def save_as_jsonl( self, samples: List[Dict[str, Any]], output_path: str, ): """ Save formatted samples as JSONL. Args: samples: List of formatted samples output_path: Output file path """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: for sample in samples: f.write(json.dumps(sample, ensure_ascii=False) + "\n") print(f"Saved {len(samples)} samples to {output_path}") def load_from_jsonl( self, input_path: str, ) -> List[Dict[str, Any]]: """ Load formatted samples from JSONL. Args: input_path: Input file path Returns: List of samples """ samples = [] with open(input_path, "r", encoding="utf-8") as f: for line in f: samples.append(json.loads(line)) print(f"Loaded {len(samples)} samples from {input_path}") return samples def validate_sample( self, sample: Dict[str, Any], ) -> bool: """ Validate a formatted sample. Args: sample: Sample to validate Returns: True if valid """ if "messages" not in sample: print("Error: Missing 'messages' field") return False messages = sample["messages"] if len(messages) < 2: print("Error: At least 2 messages required (system + user)") return False if messages[0]["role"] != "system": print("Error: First message must be system") return False # Check alternating user/assistant for i in range(1, len(messages), 2): if messages[i]["role"] != "user": print(f"Error: Expected user at position {i}, got {messages[i]['role']}") return False if i + 1 < len(messages) and messages[i + 1]["role"] != "assistant": print(f"Error: Expected assistant at position {i+1}, got {messages[i+1]['role']}") return False return True def create_pretraining_dataset( self, qa_samples: List[Dict[str, Any]], output_dir: str, train_split: float = 0.9, ) -> Dict[str, str]: """ Create train/val splits for fine-tuning. Args: qa_samples: List of QA samples output_dir: Output directory train_split: Train split ratio (0-1) Returns: Dictionary with train/val file paths """ import random random.shuffle(qa_samples) split_idx = int(len(qa_samples) * train_split) train_samples = qa_samples[:split_idx] val_samples = qa_samples[split_idx:] output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) train_path = output_dir / "train.jsonl" val_path = output_dir / "val.jsonl" self.save_as_jsonl(train_samples, str(train_path)) self.save_as_jsonl(val_samples, str(val_path)) print(f"Created splits: train={len(train_samples)}, val={len(val_samples)}") return { "train": str(train_path), "val": str(val_path), } def test_chat_formatter(): """Test the ChatFormatter.""" # Create formatter formatter = ChatFormatter() print("Testing ChatFormatter...\n") # Test QA pair formatting qa = formatter.format_qa_pair( question="How do I play a G chord?", answer="[TAB]...[/TAB] Here's how...", context="[GUITAR][BEGINNER]", ) print("Formatted QA pair:") for msg in qa["messages"]: print(f" {msg['role']}: {msg['content'][:80]}...") # Test validation is_valid = formatter.validate_sample(qa) print(f"\nSample valid: {is_valid}") # Test multi-turn multi_turn = formatter.format_multi_turn([ {"role": "user", "content": "What is a chord?"}, {"role": "assistant", "content": "A chord is..."}, {"role": "user", "content": "Can you give an example?"}, {"role": "assistant", "content": "C major is C-E-G"}, ]) print("\nMulti-turn format:") for msg in multi_turn["messages"]: print(f" {msg['role']}: {msg['content'][:60]}...") print("\nChatFormatter test complete!") if __name__ == "__main__": test_chat_formatter()