File size: 11,375 Bytes
9071ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""

Chat Formatter for TouchGrass.

Formats data into chat format compatible with Qwen3.5 fine-tuning.

"""

from typing import List, Dict, Any, Optional
import json
from pathlib import Path


class ChatFormatter:
    """Format music QA data into chat-style samples for instruction tuning.

    Every sample produced here is a dict ``{"messages": [...]}`` whose
    entries are ``{"role": ..., "content": ...}`` dicts, with the system
    prompt as the first message.

    Handles:
    - System prompt injection
    - Context tags (instrument, skill level, emotion)
    - Optional token-length checks and truncation when a tokenizer is given
    - Multi-turn conversations
    - JSONL save/load and train/val splitting
    """

    def __init__(
        self,
        tokenizer=None,
        max_seq_length: int = 4096,
        system_prompt: Optional[str] = None,
    ):
        """Initialize chat formatter.

        Args:
            tokenizer: Optional tokenizer for length validation. It must
                provide ``encode(text)`` and ``decode(ids)``.
            max_seq_length: Maximum sequence length, in tokens.
            system_prompt: Optional custom system prompt. A falsy value
                (``None`` or ``""``) selects the built-in default.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        self.default_system_prompt = system_prompt or self._get_default_system_prompt()

    def _get_default_system_prompt(self) -> str:
        """Get default system prompt."""
        return """You are Touch Grass 🌿, a warm, encouraging, and knowledgeable music assistant.



You help people with:

- Learning instruments (guitar, bass, piano, keys, drums, vocals)

- Understanding music theory at any level

- Writing songs (lyrics, chord progressions, structure)

- Ear training and developing musicality

- DJ skills and music production

- Genre knowledge and music history



Your personality:

- Patient and encouraging — learning music is hard and takes time

- Adapt to the learner's level automatically — simpler for beginners, deeper for advanced

- When someone is frustrated, acknowledge it warmly before helping

- Use tabs, chord diagrams, and notation when helpful

- Make learning fun, not intimidating

- Celebrate small wins



When generating tabs use this format:

[TAB]

e|---------|

B|---------|

G|---------|

D|---------|

A|---------|

E|---------|

[/TAB]



When showing chord progressions use: [PROGRESSION]I - IV - V - I[/PROGRESSION]"""

    def format_qa_pair(
        self,
        question: str,
        answer: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Format a single QA pair into chat format.

        Args:
            question: User question.
            answer: Assistant answer.
            context: Optional context tags (e.g., "[GUITAR][BEGINNER]"),
                prepended to the question separated by one space.
            system_prompt: Optional system prompt override.

        Returns:
            Formatted chat dictionary ``{"messages": [system, user, assistant]}``.
            When a tokenizer is configured and the sample is too long, a
            warning is printed and the answer is truncated.
        """
        system = system_prompt or self.default_system_prompt

        # Build user message with context
        user_message = f"{context} {question}".strip() if context else question

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": answer},
        ]

        # Validate length if tokenizer provided. Compare against None rather
        # than relying on truthiness: a tokenizer object may define __len__.
        if self.tokenizer is not None:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Sample exceeds max length ({total_length} > {self.max_seq_length})")
                # Truncate answer if needed
                messages = self._truncate_answers(messages)

        return {"messages": messages}

    def format_multi_turn(
        self,
        conversations: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Format multi-turn conversation.

        Args:
            conversations: List of {"role": "...", "content": "..."} dicts.
                May be empty, in which case only the system message is
                returned.
            system_prompt: Optional system prompt. Only used when the
                conversation does not already start with a system message.

        Returns:
            Formatted chat dictionary. The message list is always a new
            list, so later edits to it never affect ``conversations``.
        """
        system = system_prompt or self.default_system_prompt

        # Copy first: the result must not alias the caller's list.
        messages = list(conversations)

        # Ensure system is first (also covers the empty-conversation case).
        if not messages or messages[0].get("role") != "system":
            messages.insert(0, {"role": "system", "content": system})

        # Validate length
        if self.tokenizer is not None:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Multi-turn sample exceeds max length ({total_length} > {self.max_seq_length})")
                messages = self._truncate_multi_turn(messages)

        return {"messages": messages}

    def _encode_ids(self, text: str) -> List[Any]:
        """Encode ``text`` with the tokenizer and return a flat list of ids.

        Tokenizer APIs disagree on what ``encode`` returns, so the common
        shapes are normalised here:
        - a mapping with an ``"input_ids"`` entry,
        - an object exposing an ``ids`` attribute,
        - a plain sequence of ids.
        """
        encoded = self.tokenizer.encode(text)
        if hasattr(encoded, "keys"):
            encoded = encoded["input_ids"]
        elif hasattr(encoded, "ids"):
            encoded = encoded.ids
        return list(encoded)

    def _estimate_length(self, messages: List[Dict[str, str]]) -> int:
        """Estimate token length of messages.

        Only message contents are counted; role markers or chat-template
        tokens are not included. Returns 0 when no tokenizer is set.
        """
        if self.tokenizer is None:
            return 0

        return sum(len(self._encode_ids(msg["content"])) for msg in messages)

    def _truncate_answers(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Truncate the answer of a [system, user, assistant] sample.

        Args:
            messages: Exactly the three messages built by ``format_qa_pair``.

        Returns:
            A list where the assistant content has been cut to fit the
            remaining token budget and suffixed with "...". If the answer
            already fits, ``messages`` is returned unchanged. The input
            message dicts are never modified.
        """
        if self.tokenizer is None:
            return messages

        prompt_len = self._estimate_length(messages[:2])
        available = self.max_seq_length - prompt_len - 10  # buffer

        answer_ids = self._encode_ids(messages[2]["content"])
        if len(answer_ids) <= available:
            return messages

        # Reserve 3 tokens for the ellipsis. Clamp at zero: a negative stop
        # index would slice from the END and keep most of the answer.
        keep = max(0, available - 3)
        truncated = self.tokenizer.decode(answer_ids[:keep]) if keep else ""

        answer_msg = dict(messages[2])
        answer_msg["content"] = truncated + "..."
        return [messages[0], messages[1], answer_msg]

    def _truncate_multi_turn(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Truncate multi-turn conversation from the end.

        The system message is always kept. Following messages are kept in
        order until adding the next one would exceed
        ``max_seq_length - 10``; that message and everything after it are
        dropped.
        """
        if self.tokenizer is None:
            return messages

        system_msg, *later_msgs = messages
        budget = self.max_seq_length - 10  # buffer
        used = self._estimate_length([system_msg])

        kept_msgs = []
        for msg in later_msgs:
            msg_len = self._estimate_length([msg])
            if used + msg_len > budget:
                break
            kept_msgs.append(msg)
            used += msg_len

        return [system_msg] + kept_msgs

    def save_as_jsonl(
        self,
        samples: List[Dict[str, Any]],
        output_path: str,
    ):
        """Save formatted samples as JSONL (one JSON object per line).

        Parent directories are created if missing and an existing file is
        overwritten.

        Args:
            samples: List of formatted samples.
            output_path: Output file path.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            # ensure_ascii=False keeps non-ASCII text (e.g. the emoji in the
            # system prompt) readable in the file.
            f.writelines(
                json.dumps(sample, ensure_ascii=False) + "\n" for sample in samples
            )

        print(f"Saved {len(samples)} samples to {output_path}")

    def load_from_jsonl(
        self,
        input_path: str,
    ) -> List[Dict[str, Any]]:
        """Load formatted samples from JSONL.

        Blank or whitespace-only lines (such as a trailing newline at the
        end of the file) are skipped instead of raising a decode error.

        Args:
            input_path: Input file path.

        Returns:
            List of samples.
        """
        with open(input_path, "r", encoding="utf-8") as f:
            samples = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(samples)} samples from {input_path}")
        return samples

    def validate_sample(
        self,
        sample: Dict[str, Any],
    ) -> bool:
        """Validate a formatted sample.

        A valid sample has a ``"messages"`` list of at least two entries,
        starting with a system message and then alternating user/assistant
        roles. Malformed input (non-dict sample or message, missing
        ``"role"``) yields ``False`` rather than an exception. The reason
        for a failure is printed.

        Args:
            sample: Sample to validate.

        Returns:
            True if valid.
        """
        if not isinstance(sample, dict) or "messages" not in sample:
            print("Error: Missing 'messages' field")
            return False

        messages = sample["messages"]
        if not isinstance(messages, (list, tuple)):
            print("Error: 'messages' must be a list")
            return False
        if len(messages) < 2:
            print("Error: At least 2 messages required (system + user)")
            return False

        def role_of(msg: Any) -> Any:
            # None for anything that is not a dict carrying a "role" key.
            return msg.get("role") if isinstance(msg, dict) else None

        if role_of(messages[0]) != "system":
            print("Error: First message must be system")
            return False

        # Check alternating user/assistant: odd positions are user turns,
        # the even position after each one (if present) is the reply.
        for i in range(1, len(messages), 2):
            user_role = role_of(messages[i])
            if user_role != "user":
                print(f"Error: Expected user at position {i}, got {user_role}")
                return False
            if i + 1 < len(messages):
                reply_role = role_of(messages[i + 1])
                if reply_role != "assistant":
                    print(f"Error: Expected assistant at position {i+1}, got {reply_role}")
                    return False

        return True

    def create_pretraining_dataset(
        self,
        qa_samples: List[Dict[str, Any]],
        output_dir: str,
        train_split: float = 0.9,
        seed: Optional[int] = None,
    ) -> Dict[str, str]:
        """Create train/val splits for fine-tuning.

        The samples are shuffled on a copy, so ``qa_samples`` itself is
        left in its original order, then written to ``train.jsonl`` and
        ``val.jsonl`` inside ``output_dir``.

        Args:
            qa_samples: List of QA samples.
            output_dir: Output directory (created if missing).
            train_split: Train split ratio (0-1).
            seed: Optional seed for a reproducible shuffle. When ``None``
                the module-level ``random`` generator is used.

        Returns:
            Dictionary with train/val file paths.

        Raises:
            ValueError: If ``train_split`` is outside ``[0, 1]``.
        """
        import random

        if not 0.0 <= train_split <= 1.0:
            raise ValueError(f"train_split must be between 0 and 1, got {train_split}")

        shuffled = list(qa_samples)
        if seed is None:
            random.shuffle(shuffled)
        else:
            # A private generator keeps the global random state untouched.
            random.Random(seed).shuffle(shuffled)

        split_idx = int(len(shuffled) * train_split)
        train_samples = shuffled[:split_idx]
        val_samples = shuffled[split_idx:]

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "val.jsonl"

        self.save_as_jsonl(train_samples, str(train_path))
        self.save_as_jsonl(val_samples, str(val_path))

        print(f"Created splits: train={len(train_samples)}, val={len(val_samples)}")

        return {
            "train": str(train_path),
            "val": str(val_path),
        }


def test_chat_formatter():
    """Smoke-run ChatFormatter and print what it produces.

    Exercises single QA formatting, sample validation, and multi-turn
    formatting with a tokenizer-less formatter. Results are printed for
    manual inspection; nothing is asserted.
    """
    fmt = ChatFormatter()

    print("Testing ChatFormatter...\n")

    # --- single question/answer pair with context tags ---
    qa_sample = fmt.format_qa_pair(
        question="How do I play a G chord?",
        answer="[TAB]...[/TAB] Here's how...",
        context="[GUITAR][BEGINNER]",
    )

    print("Formatted QA pair:")
    for turn in qa_sample["messages"]:
        role = turn["role"]
        preview = turn["content"][:80]
        print(f"  {role}: {preview}...")

    # --- structural validation of that sample ---
    print(f"\nSample valid: {fmt.validate_sample(qa_sample)}")

    # --- multi-turn conversation without an explicit system message ---
    dialogue = [
        {"role": "user", "content": "What is a chord?"},
        {"role": "assistant", "content": "A chord is..."},
        {"role": "user", "content": "Can you give an example?"},
        {"role": "assistant", "content": "C major is C-E-G"},
    ]
    multi_sample = fmt.format_multi_turn(dialogue)

    print("\nMulti-turn format:")
    for turn in multi_sample["messages"]:
        role = turn["role"]
        preview = turn["content"][:60]
        print(f"  {role}: {preview}...")

    print("\nChatFormatter test complete!")


# Run the smoke test only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_chat_formatter()