File size: 4,029 Bytes
aa9be1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Generate multi-turn dialogue data, each turn contains lyric text and corresponding audio token slices
"""

import json
import re
import os
import torch
from tqdm import tqdm
from my_tool import load_jsonl

TOKEN_PER_SECOND = 25  # Number of tokens per second of audio
NUM_ITEMS = 100000  # Process N items first

timestamp_pattern = re.compile(r"\[([0-9]{1,2}):([0-9]{1,2})(?:[.:]([0-9]{1,3}))?\]")

def _parse_lyric_with_timestamps(lyric: str):
    """
    Return [(start_time_s, text), ...] sorted by timestamp
    """
    result = []
    for match in timestamp_pattern.finditer(lyric):
        start_idx = match.end()
        end_idx = lyric.find("[", start_idx)
        text = lyric[start_idx:end_idx].strip() if end_idx != -1 else lyric[start_idx:].strip()
        if not text:
            continue
        minute = int(match.group(1))
        second = int(match.group(2))
        ms = int(match.group(3)) if match.group(3) else 0
        total_seconds = minute * 60 + second + ms / 1000
        result.append((total_seconds, text))
    return result

def _load_audio_tokens(pt_file):
    """
    Load MuCodec encoding of audio
    """
    audio_ids = torch.load(pt_file, map_location="cpu").squeeze().long()
    return audio_ids

def _get_token_slice(audio_tokens, start_s, end_s):
    """Split encoding by time segment"""
    start_idx = int(start_s * TOKEN_PER_SECOND)
    end_idx = int(end_s * TOKEN_PER_SECOND)
    sliced = audio_tokens[start_idx:end_idx]
    return "[SOA]" + "".join([f"<AUDIO_{i.item()}>" for i in sliced]) + "[EOA]"

def _process_item(item, pt_dir:str):
    song_name = item.get("song") or item.get("name")
    song_name = song_name.split('.mp3')[0]              # For mucodec, remove extension
    pt_file = os.path.join(pt_dir, f"{song_name}.pt")
    if not os.path.exists(pt_file):
        return None

    audio_tokens = _load_audio_tokens(pt_file)
    tlyric_ = item.get('tlyric', "")
    lyric_ = item.get('lyric', "")
    lyric = tlyric_ if len(tlyric_) > len(lyric_) else lyric_
    lyrics_ts = _parse_lyric_with_timestamps(lyric)

    if not lyrics_ts:
        # Skip if no lyrics
        return None

    rounds = []

    # First generate a system message containing song information
    intro_text = (
        f"่ฏท็”Ÿๆˆไธ€้ฆ–ๆญŒๆ›ฒ๏ผŒๆญŒๅไธบใ€Š{item.get('name', '')}ใ€‹๏ผŒ้ฃŽๆ ผๆ˜ฏ{item.get('style','')}"
        f"๏ผŒๆƒ…็ปชไธบ{item.get('emotion','')}๏ผŒ่Š‚ๅฅ๏ผš{item.get('rhythm','')}๏ผŒ"
        f"{item.get('description','')}๏ผŒ็”ฑ{item.get('singer','')}ๆผ”ๅ”ฑ๏ผŒ่ฏญ่จ€๏ผš{item.get('lang','')}ใ€‚"
        f"ๆญŒ่ฏๅฆ‚ไธ‹๏ผš" + " ".join([text for _, text in lyrics_ts]) + "ๆŽฅไธ‹ๆฅๆˆ‘ไผš้€ๅฅๅ‘Š่ฏ‰ไฝ ้œ€่ฆ็”ŸๆˆๆญŒๆ›ฒ็‰‡ๆฎต็š„ๆญŒ่ฏ๏ผŒ\n่ฏทๅ…ˆ็”Ÿๆˆๅ‰ๅฅ"
    )
    rounds.append({"role": "user", "content": intro_text})
    rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, 0, lyrics_ts[0][0])})  # Intro tokens

    # Each lyric line corresponds to one round
    for idx, (start_s, text) in enumerate(lyrics_ts[:-1]):  ## Last line handled separately
        end_s = lyrics_ts[idx + 1][0] if idx + 1 < len(lyrics_ts) else len(audio_tokens)/TOKEN_PER_SECOND  # Last line to end of audio
        rounds.append({"role": "user", "content": text})
        rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, start_s, end_s)})

    # Tail processing logic
    rounds.append({"role": "user", "content": f"่ฏท็”ŸๆˆๆญŒ่ฏ{lyrics_ts[-1][1]}ไปฅๅŠๆญŒๆ›ฒ็ป“ๅฐพ"})
    rounds.append({"role": "assistant", "content": _get_token_slice(audio_tokens, lyrics_ts[-1][0], len(audio_tokens)/TOKEN_PER_SECOND)})

    return rounds

# ===== External Interface =====

def get_convert_convs(dataset:list[dict], pt_dir:str, save_path:str):
    with open(save_path, "w", encoding="utf-8") as fout:
        for item in tqdm(dataset, desc="Converting convs"):
            rounds = _process_item(item, pt_dir)
            if not rounds:
                continue
            fout.write(json.dumps({"messages": rounds}, ensure_ascii=False) + "\n")