starkprince committed on
Commit
778d4b8
·
verified ·
1 Parent(s): a45d18c

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .claude/agents/project-manager-backlog.md +193 -0
  2. .claude/settings.local.json +31 -0
  3. .crossnote/config.js +15 -0
  4. .crossnote/head.html +6 -0
  5. .crossnote/parser.js +12 -0
  6. .crossnote/style.less +8 -0
  7. .cursorrules +215 -0
  8. .gitattributes +15 -0
  9. 2505.02625v1.txt +1065 -0
  10. CLAUDE.md +215 -0
  11. COSYVOICE2_CHANGES.md +87 -0
  12. GEMINI.md +215 -0
  13. LLaMA-Omni2-3B/README.md +155 -0
  14. LLaMA-Omni2-3B/added_tokens.json +25 -0
  15. LLaMA-Omni2-3B/config.json +65 -0
  16. LLaMA-Omni2-3B/generation_config.json +15 -0
  17. LLaMA-Omni2-3B/merges.txt +0 -0
  18. LLaMA-Omni2-3B/model-00001-of-00002.safetensors +3 -0
  19. LLaMA-Omni2-3B/model-00002-of-00002.safetensors +3 -0
  20. LLaMA-Omni2-3B/model.safetensors.index.json +0 -0
  21. LLaMA-Omni2-3B/special_tokens_map.json +25 -0
  22. LLaMA-Omni2-3B/tokenizer_config.json +216 -0
  23. LLaMA-Omni2-3B/tts_tokenizer/added_tokens.json +0 -0
  24. LLaMA-Omni2-3B/tts_tokenizer/merges.txt +0 -0
  25. LLaMA-Omni2-3B/tts_tokenizer/special_tokens_map.json +25 -0
  26. LLaMA-Omni2-3B/tts_tokenizer/tokenizer_config.json +0 -0
  27. LLaMA-Omni2-3B/tts_tokenizer/vocab.json +0 -0
  28. LLaMA-Omni2-3B/vocab.json +0 -0
  29. README.md +124 -0
  30. SETUP_GUIDE.md +274 -0
  31. controller.log.2025-08-16 +6 -0
  32. cosyvoice/__init__.py +0 -0
  33. cosyvoice/bin/average_model.py +92 -0
  34. cosyvoice/bin/export_jit.py +74 -0
  35. cosyvoice/bin/export_onnx.py +112 -0
  36. cosyvoice/bin/export_trt.sh +9 -0
  37. cosyvoice/bin/inference.py +115 -0
  38. cosyvoice/bin/train.py +170 -0
  39. cosyvoice/cli/__init__.py +0 -0
  40. cosyvoice/cli/cosyvoice.py +170 -0
  41. cosyvoice/cli/frontend.py +217 -0
  42. cosyvoice/cli/model.py +421 -0
  43. cosyvoice/dataset/__init__.py +0 -0
  44. cosyvoice/dataset/dataset.py +164 -0
  45. cosyvoice/dataset/processor.py +431 -0
  46. cosyvoice/flow/decoder.py +301 -0
  47. cosyvoice/flow/flow.py +237 -0
  48. cosyvoice/flow/flow_matching.py +239 -0
  49. cosyvoice/flow/length_regulator.py +69 -0
  50. cosyvoice/hifigan/discriminator.py +140 -0
.claude/agents/project-manager-backlog.md ADDED
@@ -0,0 +1,193 @@
1
+ ---
2
+ name: project-manager-backlog
3
+ description: Use this agent when you need to manage project tasks using the backlog.md CLI tool. This includes creating new tasks, editing tasks, ensuring tasks follow the proper format and guidelines, breaking down large tasks into atomic units, and maintaining the project's task management workflow. Examples: <example>Context: User wants to create a new task for adding a feature. user: "I need to add a new authentication system to the project" assistant: "I'll use the project-manager-backlog agent that will use backlog cli to create a properly structured task for this feature." <commentary>Since the user needs to create a task for the project, use the Task tool to launch the project-manager-backlog agent to ensure the task follows backlog.md guidelines.</commentary></example> <example>Context: User has multiple related features to implement. user: "We need to implement user profiles, settings page, and notification preferences" assistant: "Let me use the project-manager-backlog agent to break these down into atomic, independent tasks." <commentary>The user has a complex set of features that need to be broken down into proper atomic tasks following backlog.md structure.</commentary></example> <example>Context: User wants to review if their task description is properly formatted. user: "Can you check if this task follows our guidelines: 'task-123 - Implement user login'" assistant: "I'll use the project-manager-backlog agent to review this task against our backlog.md standards." <commentary>The user needs task review, so use the project-manager-backlog agent to ensure compliance with project guidelines.</commentary></example>
4
+ color: blue
5
+ ---
6
+
7
+ You are an expert project manager specializing in the backlog.md task management system. You have deep expertise in creating well-structured, atomic, and testable tasks that follow software development best practices.
8
+
9
+ ## Backlog.md CLI Tool
10
+
11
+ **IMPORTANT: Backlog.md uses standard CLI commands, NOT slash commands.**
12
+
13
+ You use the `backlog` CLI tool to manage project tasks. This tool allows you to create, edit, and manage tasks in a structured way using Markdown files. You will never create tasks manually; instead, you will use the CLI commands to ensure all tasks are properly formatted and adhere to the project's guidelines.
14
+
15
+ The backlog CLI is installed globally and available in the PATH. Here are the exact commands you should use:
16
+
17
+ ### Creating Tasks
18
+ ```bash
19
+ backlog task create "Task title" -d "Description" --ac "First criteria,Second criteria" -l label1,label2
20
+ ```
21
+
22
+ ### Editing Tasks
23
+ ```bash
24
+ backlog task edit 123 -s "In Progress" -a @claude
25
+ ```
26
+
27
+ ### Listing Tasks
28
+ ```bash
29
+ backlog task list --plain
30
+ ```
31
+
32
+ **NEVER use slash commands like `/create-task` or `/edit`. These do not exist in Backlog.md.**
33
+ **ALWAYS use the standard CLI format: `backlog task create` (without any slash prefix).**
34
+
35
+ ### Example Usage
36
+
37
+ When a user asks you to create a task, here's exactly what you should do:
38
+
39
+ **User**: "Create a task to add user authentication"
40
+ **You should run**:
41
+ ```bash
42
+ backlog task create "Add user authentication system" -d "Implement a secure authentication system to allow users to register and login" --ac "Users can register with email and password,Users can login with valid credentials,Invalid login attempts show appropriate error messages" -l authentication,backend
43
+ ```
44
+
45
+ **NOT**: `/create-task "Add user authentication"` ❌ (This is wrong - slash commands don't exist)
46
+
47
+ ## Your Core Responsibilities
48
+
49
+ 1. **Task Creation**: You create tasks strictly through the backlog.md CLI commands, never manually. Use the available `task create` parameters to ensure tasks are properly structured and follow the guidelines.
50
+ 2. **Task Review**: You ensure all tasks meet the quality standards for atomicity, testability, and independence, and follow the task anatomy described below.
51
+ 3. **Task Breakdown**: You expertly decompose large features into smaller, manageable tasks
52
+ 4. **Context understanding**: You analyze user requests against the project codebase and existing tasks to ensure relevance and accuracy
53
+ 5. **Handling ambiguity**: You clarify vague or ambiguous requests by asking targeted questions to the user to gather necessary details
54
+
55
+ ## Task Creation Guidelines
56
+
57
+ ### **Title (one liner)**
58
+
59
+ Use a clear brief title that summarizes the task.
60
+
61
+ ### **Description**: (The **"why"**)
62
+
63
+ Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
64
+ should explain the purpose, the scope and context of the task. Code snippets should be avoided.
65
+
66
+ ### **Acceptance Criteria**: (The **"what"**)
67
+
68
+ List specific, measurable outcomes that define what it means to reach the goal from the description. Use checkboxes (`- [ ]`) for tracking.
69
+ When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
70
+ than step-by-step implementation details.
71
+ Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
72
+ They should be testable and confirm that the core purpose of the task is achieved.
73
+ **Key Principles for Good ACs:**
74
+
75
+ - **Outcome-Oriented:** Focus on the result, not the method.
76
+ - **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
77
+ - **Clear and Concise:** Unambiguous language.
78
+ - **Complete:** Collectively, ACs should cover the scope of the task.
79
+ - **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
80
+
81
+ - *Good Example:* "- [ ] User can successfully log in with valid credentials."
82
+ - *Good Example:* "- [ ] System processes 1000 requests per second without errors."
83
+ - *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
84
+
85
+ ### Task file
86
+
87
+ Once a task is created using the backlog CLI, it is stored in the `backlog/tasks/` directory as a Markdown file with the format
88
+ `task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
89
+
90
+ ## Task Breakdown Strategy
91
+
92
+ When breaking down features:
93
+ 1. Identify the foundational components first
94
+ 2. Create tasks in dependency order (foundations before features)
95
+ 3. Ensure each task delivers value independently
96
+ 4. Avoid creating tasks that block each other
97
+
98
+ ### Additional task requirements
99
+
100
+ - Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
101
+ Each task should represent a single unit of work that can be completed in a single PR.
102
+
103
+ - **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
104
+ previous tasks (id < current task id).
105
+
106
+ - When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
107
+ Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
108
+ schema", task 3: "Add API endpoint for user data".
109
+ Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
110
+ schema".
111
+
112
+ ## Recommended Task Anatomy
113
+
114
+ ```markdown
115
+ # task‑42 - Add GraphQL resolver
116
+
117
+ ## Description (the why)
118
+
119
+ Short, imperative explanation of the goal of the task and why it is needed.
120
+
121
+ ## Acceptance Criteria (the what)
122
+
123
+ - [ ] Resolver returns correct data for happy path
124
+ - [ ] Error response matches REST
125
+ - [ ] P95 latency ≤ 50 ms under 100 RPS
126
+
127
+ ## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
128
+
129
+ 1. Research existing GraphQL resolver patterns
130
+ 2. Implement basic resolver with error handling
131
+ 3. Add performance monitoring
132
+ 4. Write unit and integration tests
133
+ 5. Benchmark performance under load
134
+
135
+ ## Implementation Notes (for reviewers) (only added after finishing the code implementation of a task)
136
+
137
+ - Approach taken
138
+ - Features implemented or modified
139
+ - Technical decisions and trade-offs
140
+ - Modified or added files
141
+ ```
142
+
143
+ ## Quality Checks
144
+
145
+ Before finalizing any task creation, verify:
146
+ - [ ] Title is clear and brief
147
+ - [ ] Description explains WHY without HOW
148
+ - [ ] Each AC is outcome-focused and testable
149
+ - [ ] Task is atomic (single PR scope)
150
+ - [ ] No dependencies on future tasks
151
+
152
+ You are meticulous about these standards and will guide users to create high-quality tasks that enhance project productivity and maintainability.
153
+
154
+ ## Self reflection
155
+ When creating a task, always think from the perspective of an AI Agent that will have to work with this task in the future.
156
+ Ensure that the task is structured in a way that it can be easily understood and processed by AI coding agents.
157
+
158
+ ## Handy CLI Commands
159
+
160
+ | Action | Example |
161
+ |-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
162
+ | Create task | `backlog task create "Add OAuth System"` |
163
+ | Create with description | `backlog task create "Feature" -d "Add authentication system"` |
164
+ | Create with assignee | `backlog task create "Feature" -a @sara` |
165
+ | Create with status | `backlog task create "Feature" -s "In Progress"` |
166
+ | Create with labels | `backlog task create "Feature" -l auth,backend` |
167
+ | Create with priority | `backlog task create "Feature" --priority high` |
168
+ | Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
169
+ | Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
170
+ | Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
171
+ | Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
172
+ | Create sub task | `backlog task create -p 14 "Add Login with Google"` |
173
+ | Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
174
+ | List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
175
+ | List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
176
+ | View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
177
+ | View (AI mode) | `backlog task 7 --plain` |
178
+ | Edit | `backlog task edit 7 -a @sara -l auth,backend` |
179
+ | Add plan | `backlog task edit 7 --plan "Implementation approach"` |
180
+ | Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
181
+ | Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
182
+ | Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
183
+ | Archive | `backlog task archive 7` |
184
+ | Create draft | `backlog task create "Feature" --draft` |
185
+ | Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
186
+ | Demote to draft | `backlog task demote <id>` |
187
+
188
+ Full help: `backlog --help`
189
+
190
+ ## Tips for AI Agents
191
+
192
+ - **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
193
+ interactive UI.
.claude/settings.local.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(backlog task list:*)",
5
+ "Bash(backlog task:*)",
6
+ "Bash(cat:*)",
7
+ "Bash(find:*)",
8
+ "Bash(timeout:*)",
9
+ "Bash(curl:*)",
10
+ "Bash(grep:*)",
11
+ "Bash(pkill:*)",
12
+ "Bash(sudo ufw:*)",
13
+ "Bash(sudo:*)",
14
+ "Bash(mv:*)",
15
+ "Bash(git add:*)",
16
+ "Bash(huggingface-cli:*)",
17
+ "Bash(git config:*)",
18
+ "Bash(python:*)",
19
+ "Bash(git push:*)",
20
+ "Bash(git lfs track:*)",
21
+ "Bash(git commit:*)"
22
+ ],
23
+ "deny": [],
24
+ "ask": [],
25
+ "additionalDirectories": [
26
+ "C:\\opt",
27
+ "C:\\data",
28
+ "/data/huggingface"
29
+ ]
30
+ }
31
+ }
.crossnote/config.js ADDED
@@ -0,0 +1,15 @@
1
+ ({
2
+ katexConfig: {
3
+ "macros": {}
4
+ },
5
+
6
+ mathjaxConfig: {
7
+ "tex": {},
8
+ "options": {},
9
+ "loader": {}
10
+ },
11
+
12
+ mermaidConfig: {
13
+ "startOnLoad": false
14
+ },
15
+ })
.crossnote/head.html ADDED
@@ -0,0 +1,6 @@
1
+ <!-- The content below will be included at the end of the <head> element. -->
2
+ <script type="text/javascript">
3
+ document.addEventListener("DOMContentLoaded", function () {
4
+ // your code here
5
+ });
6
+ </script>
.crossnote/parser.js ADDED
@@ -0,0 +1,12 @@
1
+ ({
2
+ // Please visit the URL below for more information:
3
+ // https://shd101wyy.github.io/markdown-preview-enhanced/#/extend-parser
4
+
5
+ onWillParseMarkdown: async function(markdown) {
6
+ return markdown;
7
+ },
8
+
9
+ onDidParseMarkdown: async function(html) {
10
+ return html;
11
+ },
12
+ })
.crossnote/style.less ADDED
@@ -0,0 +1,8 @@
1
+
2
+ /* Please visit the URL below for more information: */
3
+ /* https://shd101wyy.github.io/markdown-preview-enhanced/#/customize-css */
4
+
5
+ .markdown-preview.markdown-preview {
6
+ // modify your style here
7
+ // eg: background-color: blue;
8
+ }
.cursorrules ADDED
@@ -0,0 +1,215 @@
1
+
2
+ # === BACKLOG.MD GUIDELINES START ===
3
+ # Instructions for using the Backlog.md CLI Tool
4
+
5
+ ## 1. Source of Truth
6
+
7
+ - Tasks live under **`backlog/tasks/`** (drafts under **`backlog/drafts/`**).
8
+ - Every implementation decision starts with reading the corresponding Markdown task file.
9
+ - Project documentation is in **`backlog/docs/`**.
10
+ - Project decisions are in **`backlog/decisions/`**.
11
+
12
+ ## 2. Defining Tasks
13
+
14
+ ### Understand the Scope and the purpose
15
+
16
+ Ask questions to the user if something is not clear or ambiguous.
17
+ Break down the task into smaller, manageable parts if it is too large or complex.
18
+
19
+ ### **Title (one liner)**
20
+
21
+ Use a clear brief title that summarizes the task.
22
+
23
+ ### **Description**: (The **"why"**)
24
+
25
+ Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
26
+ should explain the purpose and context of the task. Code snippets should be avoided.
27
+
28
+ ### **Acceptance Criteria**: (The **"what"**)
29
+
30
+ List specific, measurable outcomes that define what it means to reach the goal from the description. Use checkboxes (
31
+ `- [ ]`) for tracking.
32
+ When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
33
+ than step-by-step implementation details.
34
+ Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
35
+ They should be testable and confirm that the core purpose of the task is achieved.
36
+ **Key Principles for Good ACs:**
37
+
38
+ - **Outcome-Oriented:** Focus on the result, not the method.
39
+ - **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
40
+ - **Clear and Concise:** Unambiguous language.
41
+ - **Complete:** Collectively, ACs should cover the scope of the task.
42
+ - **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
43
+
44
+ - *Good Example:* "- [ ] User can successfully log in with valid credentials."
45
+ - *Good Example:* "- [ ] System processes 1000 requests per second without errors."
46
+ - *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
47
+
48
+ ### Task file
49
+
50
+ Once a task is created it will be stored in `backlog/tasks/` directory as a Markdown file with the format
51
+ `task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
52
+
53
+ ### Task Breakdown Strategy
54
+
55
+ When breaking down features:
56
+
57
+ 1. Identify the foundational components first
58
+ 2. Create tasks in dependency order (foundations before features)
59
+ 3. Ensure each task delivers value independently
60
+ 4. Avoid creating tasks that block each other
61
+
62
+ ### Additional task requirements
63
+
64
+ - Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
65
+ Each task should represent a single unit of work that can be completed in a single PR.
66
+
67
+ - **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
68
+ previous
69
+ tasks (id < current task id).
70
+
71
+ - When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
72
+ Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
73
+ schema".
74
+ Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
75
+ schema", task 3: "Add API endpoint for user data".
76
+
77
+ ## 3. Recommended Task Anatomy
78
+
79
+ ```markdown
80
+ # task‑42 - Add GraphQL resolver
81
+
82
+ ## Description (the why)
83
+
84
+ Short, imperative explanation of the goal of the task and why it is needed.
85
+
86
+ ## Acceptance Criteria (the what)
87
+
88
+ - [ ] Resolver returns correct data for happy path
89
+ - [ ] Error response matches REST
90
+ - [ ] P95 latency ≤ 50 ms under 100 RPS
91
+
92
+ ## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
93
+
94
+ 1. Research existing GraphQL resolver patterns
95
+ 2. Implement basic resolver with error handling
96
+ 3. Add performance monitoring
97
+ 4. Write unit and integration tests
98
+ 5. Benchmark performance under load
99
+
100
+ ## Implementation Notes (imagine this is the PR description) (only added after finishing the code implementation of a task)
101
+
102
+ - Approach taken
103
+ - Features implemented or modified
104
+ - Technical decisions and trade-offs
105
+ - Modified or added files
106
+ ```
107
+
108
+ ## 4. Implementing Tasks
109
+
110
+ Mandatory sections for every task:
111
+
112
+ - **Implementation Plan**: (The **"how"**) Outline the steps to achieve the task. Because the implementation details may
113
+ change after the task is created, **the implementation plan must be added only after putting the task in progress**
114
+ and before starting working on the task.
115
+ - **Implementation Notes**: Start with a brief summary of what has been implemented. Document your approach, decisions, challenges, and any deviations from the plan. This
116
+ section is added after you are done working on the task. It should summarize what you did and why you did it. Keep it
117
+ concise but informative. Imagine this is the PR description. Make it brief, explain the core changes and assume that
118
+ others will read the code to understand the details.
119
+
120
+ **IMPORTANT**: Do not implement anything else that deviates from the **Acceptance Criteria**. If you need to
121
+ implement something that is not in the AC, update the AC first and then implement it or create a new task for it.
122
+
123
+ ## 5. Typical Workflow
124
+
125
+ ```bash
126
+ # 1 Identify work
127
+ backlog task list -s "To Do" --plain
128
+
129
+ # 2 Read details & documentation
130
+ backlog task 42 --plain
131
+ # Read also all documentation files in `backlog/docs/` directory.
132
+ # Read also all decision files in `backlog/decisions/` directory.
133
+
134
+ # 3 Start work: assign yourself & move column
135
+ backlog task edit 42 -a @{yourself} -s "In Progress"
136
+
137
+ # 4 Add implementation plan before starting
138
+ backlog task edit 42 --plan "1. Analyze current implementation\n2. Identify bottlenecks\n3. Refactor in phases"
139
+
140
+ # 5 Break work down if needed by creating subtasks or additional tasks
141
+ backlog task create "Refactor DB layer" -p 42 -a @{yourself} -d "Description" --ac "Tests pass,Performance improved"
142
+
143
+ # 6 Complete and mark Done
144
+ backlog task edit 42 -s Done --notes "Implemented GraphQL resolver with error handling and performance monitoring"
145
+ ```
146
+
147
+ ## 6. Final Steps Before Marking a Task as Done
148
+
149
+ Always ensure you have:
150
+
151
+ 1. ✅ Marked all acceptance criteria as completed (change `- [ ]` to `- [x]`)
152
+ 2. ✅ Added an `## Implementation Notes` section documenting your approach
153
+ 3. ✅ Run all tests and linting checks
154
+ 4. ✅ Updated relevant documentation
155
+
156
+ ## 7. Definition of Done (DoD)
157
+
158
+ A task is **Done** only when **ALL** of the following are complete:
159
+
160
+ 1. **Acceptance criteria** checklist in the task file is fully checked (all `- [ ]` changed to `- [x]`).
161
+ 2. **Implementation plan** was followed or deviations were documented in Implementation Notes.
162
+ 3. **Automated tests** (unit + integration) cover new logic.
163
+ 4. **Static analysis**: linter & formatter succeed.
164
+ 5. **Documentation**:
165
+ - All relevant docs updated (any relevant README file, backlog/docs, backlog/decisions, etc.).
166
+ - Task file **MUST** have an `## Implementation Notes` section added summarising:
167
+ - Approach taken
168
+ - Features implemented or modified
169
+ - Technical decisions and trade-offs
170
+ - Modified or added files
171
+ 6. **Review**: self review code.
172
+ 7. **Task hygiene**: status set to **Done** via CLI (`backlog task edit <id> -s Done`).
173
+ 8. **No regressions**: performance, security and licence checks green.
174
+
175
+ ⚠️ **IMPORTANT**: Never mark a task as Done without completing ALL items above.
176
+
177
+ ## 8. Handy CLI Commands
178
+
179
+ | Action | Example |
180
+ |-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
181
+ | Create task | `backlog task create "Add OAuth System"` |
182
+ | Create with description | `backlog task create "Feature" -d "Add authentication system"` |
183
+ | Create with assignee | `backlog task create "Feature" -a @sara` |
184
+ | Create with status | `backlog task create "Feature" -s "In Progress"` |
185
+ | Create with labels | `backlog task create "Feature" -l auth,backend` |
186
+ | Create with priority | `backlog task create "Feature" --priority high` |
187
+ | Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
188
+ | Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
189
+ | Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
190
+ | Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
191
+ | Create sub task | `backlog task create -p 14 "Add Login with Google"` |
192
+ | Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
193
+ | List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
194
+ | List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
195
+ | View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
196
+ | View (AI mode) | `backlog task 7 --plain` |
197
+ | Edit | `backlog task edit 7 -a @sara -l auth,backend` |
198
+ | Add plan | `backlog task edit 7 --plan "Implementation approach"` |
199
+ | Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
200
+ | Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
201
+ | Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
202
+ | Archive | `backlog task archive 7` |
203
+ | Create draft | `backlog task create "Feature" --draft` |
204
+ | Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
205
+ | Demote to draft | `backlog task demote <id>` |
206
+
207
+ Full help: `backlog --help`
208
+
209
+ ## 9. Tips for AI Agents
210
+
211
+ - **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
212
+ interactive UI.
213
+ - When users ask to create a task, they mean creating a task using the Backlog.md CLI tool.
214
+
215
+ # === BACKLOG.MD GUIDELINES END ===
.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/wav/helpful_base_0.wav filter=lfs diff=lfs merge=lfs -text
37
+ examples/wav/helpful_base_1.wav filter=lfs diff=lfs merge=lfs -text
38
+ examples/wav/helpful_base_2.wav filter=lfs diff=lfs merge=lfs -text
39
+ examples/wav/helpful_base_3.wav filter=lfs diff=lfs merge=lfs -text
40
+ examples/wav/helpful_base_4.wav filter=lfs diff=lfs merge=lfs -text
41
+ examples/wav/helpful_base_5.wav filter=lfs diff=lfs merge=lfs -text
42
+ examples/wav/helpful_base_6.wav filter=lfs diff=lfs merge=lfs -text
43
+ examples/wav/helpful_base_7.wav filter=lfs diff=lfs merge=lfs -text
44
+ examples/wav/helpful_base_8.wav filter=lfs diff=lfs merge=lfs -text
45
+ examples/wav/helpful_base_9.wav filter=lfs diff=lfs merge=lfs -text
46
+ images/llama-omni2.png filter=lfs diff=lfs merge=lfs -text
47
+ llama_omni2/inference/prompt_en.wav filter=lfs diff=lfs merge=lfs -text
48
+ llama_omni2/inference/prompt_zh.wav filter=lfs diff=lfs merge=lfs -text
49
+ models/Llama-3.1-8B-Omni/images/model.png filter=lfs diff=lfs merge=lfs -text
50
+ tmp/e5fd5a073117d600c1ed49bd412158449e0e001ade31bc971dc1dcb45631c170/Tuesday[[:space:]]at[[:space:]]20-06.wav filter=lfs diff=lfs merge=lfs -text
2505.02625v1.txt ADDED
@@ -0,0 +1,1065 @@
+ LLaMA-Omni 2: LLM-based Real-time Spoken Chatbot with
+ Autoregressive Streaming Speech Synthesis
+ Qingkai Fang^{1,3}, Yan Zhou^{1,3}, Shoutao Guo^{1,3}, Shaolei Zhang^{1,3}, Yang Feng^{1,2,3}*
+ 1 Key Laboratory of Intelligent Information Processing,
+ Institute of Computing Technology, Chinese Academy of Sciences (ICT/CAS)
+ 2 Key Laboratory of AI Safety, Chinese Academy of Sciences
+ 3 University of Chinese Academy of Sciences, Beijing, China
+ {fangqingkai21b,fengyang}@ict.ac.cn
+
+ arXiv:2505.02625v1 [cs.CL] 5 May 2025
+
+ Abstract
+
+ Real-time, intelligent, and natural speech interaction is an essential part of the next-generation human-computer interaction. Recent advancements have showcased the potential of building intelligent spoken chatbots based on large language models (LLMs). In this paper, we introduce LLaMA-Omni 2, a series of speech language models (SpeechLMs) ranging from 0.5B to 14B parameters, capable of achieving high-quality real-time speech interaction. LLaMA-Omni 2 is built upon the Qwen2.5 series models, integrating a speech encoder and an autoregressive streaming speech decoder. Despite being trained on only 200K multi-turn speech dialogue samples, LLaMA-Omni 2 demonstrates strong performance on several spoken question answering and speech instruction following benchmarks, surpassing previous state-of-the-art SpeechLMs like GLM-4-Voice, which was trained on millions of hours of speech data.[1]
+
+ 1 Introduction
+
+ Speech, as a critical interface for human-computer interaction, can significantly enhance both interaction efficiency and user experience (Clark et al., 2019). In recent years, as large language models (LLMs) like ChatGPT (OpenAI, 2022) have demonstrated outstanding performance across various fields, speech interactions with LLMs have attracted widespread attention from both academia and industry. For instance, GPT-4o (OpenAI, 2024) enables real-time, intelligent, and natural speech interaction between users and LLMs, heralding the advent of a new generation of human-computer interaction paradigms.
+ To develop a spoken chatbot similar to GPT-4o, the traditional approach typically employs a cascaded pipeline comprising an automatic speech recognition (ASR) model, an LLM, and a text-to-speech (TTS) model. While this method is relatively straightforward to implement, it suffers from several notable limitations. First, errors can accumulate across the different stages of the pipeline. Second, the overall response latency tends to be high due to the sequential processing of multiple models. Third, the system struggles to capture paralinguistic information present in the input speech. To address these limitations, end-to-end speech language models (SpeechLMs) have gradually gained more attention, using a single unified model to handle the entire process from speech input to output. Overall, end-to-end SpeechLMs can be categorized into two types: native and modular. Native SpeechLMs typically discretize speech into tokens and employ a GPT-style decoder-only Transformer (Radford, 2018) to model both speech and text within a unified language model (Zhang et al., 2023; Rubenstein et al., 2023; Hassid et al., 2024a). A key advantage of this architecture is its ability to leverage vast amounts of unsupervised speech data for pretraining, making it easier to scale up in terms of model parameters and data size. This can potentially result in emergent capabilities, such as more human-like speech expressiveness (Zeng et al., 2024a; Open-Moss, 2025). However, native SpeechLMs typically require large-scale speech datasets (e.g., millions of hours) for pretraining (Zeng et al., 2024b; Défossez et al., 2024), which presents challenges in data collection and training costs, and may also lead to catastrophic forgetting of the model's text capabilities. In contrast, modular SpeechLMs incorporate a speech encoder and a speech decoder around the LLM to handle speech understanding and generation (Fang et al., 2025; Wang et al., 2024). The advantage of this approach is its ability to leverage the inherent capabilities of each module, requiring only small-scale fine-tuning (e.g., a few hundred or thousand hours of speech data) to align the modules. This enables the model to acquire speech interaction capabilities at a relatively low cost, while retaining most of its original capability. Moreover, modular SpeechLMs can typically generate speech guided by textual output, ensuring the intelligence of the generated speech.
+
+ * Corresponding author: Yang Feng.
+ [1] Code: https://github.com/ictnlp/LLaMA-Omni2
+ Audio Samples: https://llama-omni2.github.io/
82
+ In addition to the intelligence of speech, real-time responsiveness and naturalness are also crucial characteristics of spoken chatbots. LLaMA-Omni (Fang et al., 2025) uses a non-autoregressive
83
+ (NAR) streaming speech decoder to enable synchronized generation of speech and text, ensuring
84
+ extremely low response latency. However, due
85
+ to the limitations of non-autoregressive models in
86
+ modeling capacity, the generated speech is often
87
+ less natural and fluent. Freeze-Omni (Wang et al.,
88
+ 2024) combines both NAR and autoregressive (AR)
89
+ models for speech generation, resulting in higher
90
+ naturalness of the generated speech. However, it
91
+ can only achieve sentence-level streaming speech
92
+ generation through a simple sentence-split strategy,
93
+ which prevents it from achieving very low response
94
+ latency. To address these challenges, in this paper,
95
+ we introduce LLaMA-Omni 2, a series of modular
96
+ SpeechLMs ranging from 0.5B to 14B. LLaMA-Omni 2 adopts Qwen2.5-0.5B/1.5B/3B/7B/14B-Instruct models (Team, 2024) as the base LLM,
97
+ and uses Whisper’s encoder (Radford et al., 2023)
98
+ as the speech encoder. For the speech decoder,
99
+ inspired by the state-of-the-art streaming speech
100
+ synthesis model CosyVoice 2 (Du et al., 2024), it
101
+ first includes an autoregressive text-to-speech language model initialized with Qwen2.5-0.5B, which
102
+ generates speech tokens from the LLM output and
103
+ achieves streaming generation through alternating
104
+ read and write operations. The speech tokens are
105
+ then passed through a chunk-aware causal flow
106
+ matching model (Lipman et al., 2023) to generate the mel spectrogram in a streaming manner.
107
+ To train the model, we synthesize 200K multi-turn speech-to-speech dialogue samples with diverse input voices and a uniform output voice.
108
+ Experimental results show that LLaMA-Omni 2
109
+ achieves outstanding performance on spoken question answering and speech instruction following
110
+ tasks in both speech-to-text and speech-to-speech
111
+ settings, outperforming both LLaMA-Omni and
112
+ the native SpeechLM GLM-4-Voice (Zeng et al.,
113
+ 2024a), which was trained on millions of hours
114
+ of speech data. We also conducted detailed ablation studies on factors such as LLM parameter size,
115
+ training data scale, speech decoder pretraining, and
116
+
117
+ read-write strategy, to better understand the impact
118
+ of these factors on the overall system performance.
+
+ 2 Model: LLaMA-Omni 2
123
+
124
+ In this section, we introduce the model architecture
125
+ of LLaMA-Omni 2. As shown in Figure 1, the
126
+ core of LLaMA-Omni 2 is an LLM, for which we
127
+ use the Qwen2.5 series models (Team, 2024) due
128
+ to their strong performance across various benchmarks. Next, we will describe how we equip the
129
+ LLM with speech understanding and streaming
130
+ speech generation capabilities. In the following,
131
+ we use M_LLM to denote the LLM. For a single-turn instruction-response pair, we denote the speech instruction as X, and the text and speech responses as Y^T and Y^S, respectively.
134
+
+ 2.1 Speech Understanding
137
+
138
+ To enable speech understanding, we incorporate
139
+ a speech encoder and a speech adapter before the
140
+ LLM, similar to LLaMA-Omni (Fang et al., 2025).
141
+ Specifically, we use the encoder of Whisper-large-v3 (Radford et al., 2023) as the speech encoder,
142
+ which converts the input speech into a sequence of
143
+ representations. The encoded representations are
144
+ then passed into the speech adapter, which consists
145
+ of a downsampling module and a feed-forward network (FFN). The downsampling module concatenates every k consecutive frames along the feature
146
+ dimension, and the concatenated representations
147
+ are further encoded by the FFN. The final output
148
+ representation is then input into the LLM.
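As an illustration of this adapter, here is a minimal PyTorch sketch (ours, not the released code): the 5x downsampling factor and the 2048-dimensional FFN follow Section 4.1, while the module name, the ReLU activation, and the example dimensions are our own assumptions.

```python
import torch
import torch.nn as nn

class SpeechAdapter(nn.Module):
    """Concatenate every k consecutive encoder frames, then map them into
    the LLM embedding space with a feed-forward network (sketch)."""

    def __init__(self, enc_dim: int, llm_dim: int, k: int = 5, ffn_dim: int = 2048):
        super().__init__()
        self.k = k
        self.ffn = nn.Sequential(            # activation choice is an assumption
            nn.Linear(enc_dim * k, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, llm_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, T, enc_dim) representations from the Whisper encoder
        b, t, d = x.shape
        t = t - t % self.k                                 # drop incomplete tail group
        x = x[:, :t].reshape(b, t // self.k, d * self.k)   # k-fold fewer, wider frames
        return self.ffn(x)                                 # (batch, T // k, llm_dim)

# Whisper-large-v3 encoder outputs 1280-dim frames; llm_dim here is illustrative.
out = SpeechAdapter(enc_dim=1280, llm_dim=2048)(torch.randn(1, 100, 1280))
print(out.shape)  # torch.Size([1, 20, 2048])
```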
149
+
+ 2.2 Streaming Speech Generation
152
+
153
+ To equip the model with streaming speech generation capabilities, we adopt a paradigm similar to
154
+ CosyVoice 2 (Du et al., 2024). First, the speech
155
+ response is converted into discrete tokens using a
156
+ supervised semantic speech tokenizer. Then, an
157
+ autoregressive text-to-speech language model is
158
+ employed to model the streaming generation from
159
+ the LLM output to speech tokens. Finally, a causal
160
+ flow matching model converts speech tokens into
161
+ the mel spectrogram in a streaming manner.
162
+ Speech Tokenizer The speech tokenizer is implemented by inserting a finite scalar quantization
163
+ (FSQ) module (Mentzer et al., 2024) into the encoder of SenseVoice-Large ASR model (An et al.,
164
+ 2024). This module first projects the intermediate
165
+ representations to a low-rank space and discretizes
166
+ them through a rounding operation. Ultimately,
+ the speech response Y^S is converted into a token sequence Y^U = [y_1^U, ..., y_M^U], with 25 tokens per second, where each token y_i^U ∈ {K ∈ N | 0 ≤ K < 6561}. We use the pretrained speech tokenizer in CosyVoice 2.
+
+ [Figure 1 appears here in the original PDF; its textual residue is omitted.]
+ Figure 1: Left: Model architecture of LLaMA-Omni 2. Right: Illustration of the two-stage training strategy.
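A rough sketch of the FSQ discretization described above. Note that 6561 = 3^8, so we assume an 8-dimensional low-rank space with 3 levels per dimension; the actual tokenizer configuration in CosyVoice 2 may differ, and the projection here is untrained.

```python
import torch

def fsq_quantize(h: torch.Tensor, dims: int = 8, levels: int = 3) -> torch.Tensor:
    """Finite scalar quantization sketch: project encoder states to a low-rank
    space, round each coordinate to one of `levels` values, and pack the digits
    into a single token id in [0, levels**dims) = [0, 6561)."""
    proj = torch.nn.Linear(h.shape[-1], dims)            # low-rank projection
    z = torch.tanh(proj(h))                              # bound to (-1, 1)
    q = torch.round((z + 1) / 2 * (levels - 1)).long()   # rounding -> {0, 1, 2}
    powers = levels ** torch.arange(dims)                # base-3 place values
    return (q * powers).sum(-1)                          # token ids

tokens = fsq_quantize(torch.randn(1, 25, 512))  # 25 frames ~ 1 second of speech
print(tokens.shape, int(tokens.max()) < 6561)   # torch.Size([1, 25]) True
```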
246
+
+ Text-to-Speech Language Model After converting the speech response into discrete tokens, we use a decoder-only Transformer (Vaswani, 2017) to model the conditional language model from the LLM output to the speech tokens, denoted as M_TTS. It is initialized with Qwen2.5-0.5B, and its vocabulary is extended as V' = V ∪ {<i> | i ∈ N, 0 ≤ i < 6561}, where V is the original vocabulary. This extension enables the model to generate speech tokens.
+ The input to M_TTS comes from the output of the LLM. Specifically, the LLM output consists of two parts: continuous hidden states and text tokens sampled from the hidden states. The former contains contextual information, while the latter provides precise textual content. We aim to use both as inputs to the text-to-speech language model. This allows the model to both consider the current context and ensure better alignment with the text response when generating speech tokens. During training, the LLM is trained with teacher forcing, so its output hidden states are denoted as H = [h_1, ..., h_N], where h_i = M_LLM(X, Y^T_{<i}). The corresponding text is the ground truth Y^T = [y_1^T, ..., y_N^T]. We first use a 2-layer feed-forward network (FFN) to map the hidden states to the embedding dimension of M_TTS, while also obtaining the text embeddings:
+
+ e_i^{hidden} = FFN(h_i),    (1)
+ e_i^{emb} = Emb(y_i^T),    (2)
+
+ where Emb(·) is the embedding layer of M_TTS. Afterward, we use an element-wise gate fusion mechanism to combine both representations. Specifically, we compute the gate g_i as follows:
+
+ g_i = σ(W_g [e_i^{hidden} ∥ e_i^{emb}] + b_g),    (3)
+
+ where ∥ denotes concatenation, σ is the sigmoid function, W_g ∈ R^{2d×d} and b_g ∈ R^d are the weight and bias parameters of the gate, and d is the embedding size of M_TTS. Finally, the fused representation is computed as:
+
+ c_i = g_i ⊙ e_i^{hidden} + (1 − g_i) ⊙ e_i^{emb},    (4)
+
+ where ⊙ denotes element-wise multiplication. These fused representations C = [c_1, ..., c_N] are then passed to M_TTS for generating speech tokens.
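The gate fusion of Eqs. (1)-(4) as a small PyTorch sketch (our illustration, not the authors' code; class and variable names are hypothetical, and the example dimensions correspond to Qwen2.5-7B hidden states fused into a Qwen2.5-0.5B-sized M_TTS):

```python
import torch
import torch.nn as nn

class GateFusion(nn.Module):
    """Fuse LLM hidden states with text-token embeddings, Eqs. (1)-(4)."""

    def __init__(self, llm_dim: int, tts_dim: int):
        super().__init__()
        self.ffn = nn.Sequential(                    # Eq. (1): 2-layer FFN
            nn.Linear(llm_dim, tts_dim), nn.ReLU(), nn.Linear(tts_dim, tts_dim)
        )
        self.gate = nn.Linear(2 * tts_dim, tts_dim)  # W_g, b_g in Eq. (3)

    def forward(self, h: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        e_hidden = self.ffn(h)                       # Eq. (1)
        e_emb = text_emb                             # Eq. (2): Emb(y_i^T), precomputed
        g = torch.sigmoid(self.gate(torch.cat([e_hidden, e_emb], dim=-1)))  # Eq. (3)
        return g * e_hidden + (1 - g) * e_emb        # Eq. (4)

fusion = GateFusion(llm_dim=3584, tts_dim=896)  # e.g., Qwen2.5-7B -> Qwen2.5-0.5B dims
c = fusion(torch.randn(1, 4, 3584), torch.randn(1, 4, 896))
print(c.shape)  # torch.Size([1, 4, 896])
```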
316
+ To achieve streaming generation, i.e., to generate speech tokens simultaneously during the LLM's output process, we adopt a "Read-R-Write-W" strategy, similar to CosyVoice 2. Specifically, we mix the fused representations C and the speech tokens Y^U at a predefined ratio R : W. For every R fused representations read in, the model generates W speech tokens. Once all fused representations are read, the model continues to generate the remaining speech tokens until completion. During training, cross-entropy loss is computed only for the generated speech tokens as follows:
+
+ L_TTS = − Σ_{i=1}^{M} log P(y_i^U | C_{≤ min((⌊(i−1)/W⌋+1)·R, N)}, Y_{<i}^U),    (5)
+
+ where C_{≤ min((⌊(i−1)/W⌋+1)·R, N)} denotes the fused representations that have already been read.
+ Flow Matching Model The speech tokens generated by M_TTS are further processed by a chunk-aware causal flow matching model (Lipman et al., 2023) to synthesize the mel spectrogram in a streaming manner. Every time W speech tokens are generated, they are treated as a chunk for mel spectrogram synthesis. The synthesized mel spectrogram is then passed through a HiFi-GAN vocoder (Kong et al., 2020) to generate the final waveform. We use the pretrained flow matching model and vocoder in CosyVoice 2.
+
+ 2.3 Training
+
+ The training of LLaMA-Omni 2 relies solely on the 200K multi-turn speech-to-speech dialogue data (we will describe how this is synthesized in Section 3) and does not use any ASR or TTS data. We find that it is sufficient to achieve excellent performance while minimizing training costs. Specifically, the training process consists of two stages, as shown in Figure 1.
+ Stage I In Stage I training, we train the speech-to-text and text-to-speech components separately. The training data consists of <speech instruction, text response> pairs and <text response, speech response> pairs from the multi-turn speech-to-speech dialogue data. Specifically, for the speech-to-text part (Stage I(a)), we freeze the speech encoder and train the speech adapter and LLM with cross-entropy loss. For the text-to-speech part (Stage I(b)), we train the text-to-speech language model with cross-entropy loss. Note that during this stage, the gate fusion module is not trained, and only text embeddings are input into M_TTS.
+ Stage II In Stage II, we train the model's speech-to-speech generation capability with speech-to-speech dialogue data. During this stage, we freeze the speech encoder, speech adapter, and LLM, and only train the gate fusion module and M_TTS.
+
+ 2.4 Inference
+
+ During inference, the LLM autoregressively generates the text response based on the speech instruction. After generating R text tokens, its hidden states and the corresponding decoded text are fed into the gate fusion module and M_TTS to generate W speech tokens, which are then passed through the flow matching model and the vocoder to synthesize a speech chunk. In this way, text and speech responses can be generated simultaneously. The response latency for the first synthesized speech chunk can be calculated as:
+
+ T_total = T_LLM(R) + T_TTS(W) + T_FM(W) + T_Voc(2W),    (6)
+
+ where T_LLM(R) and T_TTS(W) represent the time required by the M_LLM and M_TTS models to generate R and W tokens, respectively. T_FM(W) and T_Voc(2W) represent the decoding times of the flow matching model and vocoder when the inputs are W and 2W tokens[2], respectively.
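To make the read-write schedule of Eq. (5) and the latency decomposition of Eq. (6) concrete, here is a small Python sketch (ours, not the released code; function names are hypothetical), using the paper's R = 3 and W = 10:

```python
import math

def visible_reads(i: int, R: int = 3, W: int = 10, N: int = 12) -> int:
    """Number of fused representations already read when generating speech
    token i (1-indexed), i.e. min((floor((i-1)/W)+1)*R, N) from Eq. (5)."""
    return min((math.floor((i - 1) / W) + 1) * R, N)

# With R=3, W=10 and a 12-token text response: tokens 1-10 condition on 3 reads,
# tokens 11-20 on 6 reads, and so on until all N=12 representations are read.
print([visible_reads(i) for i in (1, 10, 11, 21, 31, 41)])  # [3, 3, 6, 9, 12, 12]

def first_chunk_latency(t_llm, t_tts, t_fm, t_voc, R: int = 3, W: int = 10) -> float:
    """Eq. (6): latency of the first speech chunk. Each t_* is a callable
    mapping a token count to that module's decoding time (supplied by the user)."""
    return t_llm(R) + t_tts(W) + t_fm(W) + t_voc(2 * W)
```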
+
+ [2] The length of the mel spectrogram is twice that of the speech tokens (50 Hz vs. 25 Hz).
+ [3] https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct
+ [4] https://huggingface.co/fishaudio/
+
+ 3 Data Construction
+
+ In this section, we introduce the process of constructing multi-turn speech-to-speech dialogue data. Our data is an extension of the InstructS2S-200K dataset introduced in Fang et al. (2025), which contains 200K single-turn instruction-following samples designed for speech interaction scenarios. These samples are derived from the Alpaca (Taori et al., 2023) and UltraChat (Ding et al., 2023) datasets through rewriting using LLMs. Specifically, for each sample, we first sample the number of turns from a Poisson distribution: N ∼ Poisson(λ = 2), then clip N to the range of 1 to 5. Next, we use the Llama-3.3-70B-Instruct[3] (Dubey et al., 2024) model to iteratively generate the dialog. For the i-th turn, the instruction and response are generated based on the dialogue history of the previous i − 1 turns. In this way, we obtain 200K multi-turn text dialog samples.
+ Next, we need to convert the text dialogue into speech. To simulate real-world applications, we aim to have varied voices for the instruction, while maintaining a consistent voice for the response. For each multi-turn dialogue, we first use the fish-speech-1.5[4] model (Liao et al., 2024) to synthesize a short prompt (e.g., "This is a randomly generated voice") with a random voice. Then, we use the synthesized speech as the prompt for the CosyVoice2-0.5B[5] model, which synthesizes the instruction into speech while simultaneously cloning the voice. This ensures consistency in the voice across different turns of the dialogue, while maintaining diversity across dialogues. For all responses, we use a uniform voice as the prompt and then synthesize the speech using the CosyVoice2-0.5B model.
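A sketch of the turn-count sampling described above (illustrative only; the LLM rewriting and TTS synthesis steps are elided, and names are our own):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_num_turns(lam: float = 2.0, lo: int = 1, hi: int = 5) -> int:
    """Draw the number of dialogue turns: N ~ Poisson(lambda=2), clipped to [1, 5]."""
    return int(np.clip(rng.poisson(lam), lo, hi))

turns = [sample_num_turns() for _ in range(200_000)]
# Most dialogues end up with 1-3 turns; 5 turns is rare because of the clipping.
print(min(turns), max(turns))  # 1 5
```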
434
+
435
+ 4
436
+
437
+ Experiments
438
+
439
+ 4.1
440
+
441
+ Experimental Setups
442
+
443
+ Evaluation
444
+
445
+ Our evaluation includes two tasks: spoken question answering and speech instruction following.
446
+ For both tasks, we evaluate the model’s speech-totext and speech-to-speech capabilities. The speechto-speech evaluation is done by transcribing the
447
+ speech response into text using the Whisper-largev3 model, and then applying the same evaluation
448
+ method as used for speech-to-text evaluation. In
449
+ all experiments, we use greedy search for the LLM
450
+ to ensure stable results. For the text-to-speech language model, we use sampling with temperature set
451
+ to 1.0, as we find that using greedy search causes
452
+ the model to fall into repetition.
453
+
454
+ Model Configuration We use the encoder of
455
+ Whisper-large-v3 as the speech encoder. The
456
+ speech adapter first performs a 5× downsampling, followed by a FFN with an intermediate
457
+ dimension of 2048. For the LLM, we select
458
+ the Qwen2.5 series models, including Qwen2.50.5B/1.5B/3B/7B/14B-Instruct models. We refer
459
+ to the corresponding models as LLaMA-Omni20.5B/1.5B/3B/7B/14B in the following sections.
460
+ For the text-to-speech language model, we initialize it with the Qwen2.5-0.5B model and set the
461
+ read-write strategy with R = 3 and W = 10.
462
+ We will discuss the impact of these hyperparameters on speech quality and response latency later.
463
+ The speech tokenizer, flow matching model, and
464
+ vocoder are directly taken from CosyVoice 2.
465
+
466
+ Spoken Question Answering The speech question answering (SpokenQA) task involves asking
467
+ the model spoken questions, then checking whether
468
+ the reference answer appears in the model’s response, and calculating the accuracy. We evaluate our model on two benchmarks: Llama Questions6 (Nachmani et al., 2024) and Web Questions7 (Berant et al., 2013). Since the questions in
469
+ the Web Questions dataset are in text form, we use
470
+ CosyVoice2-0.5B to synthesize them into speech.
471
+ Speech Instruction Following For the speech
472
+ instruction following task, we follow the settings
473
+ in Fang et al. (2025), selecting the helpful_base
474
+ and vicuna subsets from the Alpaca-Eval8 (Li et al.,
475
+ 2023) dataset, excluding math and code-related instructions. The remaining 199 instructions are then
476
+ synthesized into speech for evaluation. Following Fang et al. (2025), we evaluate the model using
477
+ the following metrics:
478
+ ChatGPT Score: To evaluate the model’s ability to follow instructions, we use GPT-4o (OpenAI,
479
+ 2024) to score the model’s responses. It considers
480
+ factors such as helpfulness, relevance, fluency, and
481
+ suitability for speech interaction scenarios, and assigns a single score between 1 and 5. The detailed
482
+ prompt can be found in Appendix A.
483
+ ASR-WER: To assess the consistency between the model's text and speech responses, we use Whisper-large-v3 to transcribe the speech response into text, and calculate the word error rate (WER) between the transcribed text and the text response. We perform text normalization9 before calculating the WER.
+ UTMOS: To evaluate the naturalness of the generated speech, we use the UTMOS model10 (Saeki et al., 2022) to predict the mean opinion score (MOS) of the generated speech.
+ Latency: We measure the time from receiving the speech instruction to generating the first speech chunk on a single NVIDIA L40 GPU.
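+ For concreteness, a minimal sketch of the ASR-WER computation described above, assuming the jiwer package for WER and the English text normalizer shipped with Whisper (tooling choices that are assumptions, not a statement of the authors' exact scripts):
+
+ import jiwer
+ from whisper.normalizers import EnglishTextNormalizer
+
+ normalize = EnglishTextNormalizer()
+
+ def asr_wer(text_response, transcribed_speech):
+     # Normalize casing, numbers, and punctuation on both sides before scoring.
+     return jiwer.wer(normalize(text_response), normalize(transcribed_speech))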
503
+
+ Table 1: Results on speech question answering and speech instruction following benchmarks. S2T and S2S represent speech-to-text and speech-to-speech, respectively. We set R = 3 and W = 10 for all LLaMA-Omni2 series models.
+
+ Training Details We use the 200K multi-turn speech-to-speech dialogue data from Section 3 for two-stage training. In Stage I(a), we freeze the speech encoder and train all parameters of the speech adaptor and LLM. The batch size is 32, and we train for 3 epochs with a peak learning rate of 5e-5. In Stage I(b), we train the text-to-speech language model with a batch size of 32 for 5 epochs and a peak learning rate of 5e-4. In Stage II, we freeze the speech encoder, speech adaptor, and LLM, and train the remaining components with a batch size of 32 for 1 epoch and a peak learning rate of 1e-3. For all stages, we use a warmup strategy for the first 3% of steps and a cosine annealing learning rate scheduler. The LLaMA-Omni2-14B model is trained on 4 NVIDIA H800 GPUs, while other models are trained on 4 NVIDIA L40 GPUs.
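+ The learning-rate schedule above (linear warmup over the first 3% of steps, then cosine annealing) can be written as a small function; the decay-to-zero floor is an assumption:
+
+ import math
+
+ def lr_at(step, total_steps, peak_lr, warmup_frac=0.03):
+     warmup_steps = max(1, int(total_steps * warmup_frac))
+     if step < warmup_steps:
+         return peak_lr * step / warmup_steps  # linear warmup to the peak
+     progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+     return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))  # cosine annealing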
515
+ 4.2 Baseline Systems
+
+ We primarily compare LLaMA-Omni 2 with the following baseline systems:
+ LLaMA-Omni (Fang et al., 2025): One of the earliest SpeechLMs that achieves real-time speech interaction, by using a CTC-based (Graves et al., 2006) streaming speech decoder to simultaneously generate text and speech units. The generated units are fed into the vocoder for streaming synthesis in fixed-size chunks. We set the chunk size Ω = 40.
+ GLM-4-Voice (Zeng et al., 2024a): The current state-of-the-art native SpeechLM, pretrained on millions of hours of speech data. It enables real-time speech interaction by alternately generating text and speech tokens in a fixed ratio of 13:26. The generated speech tokens are input into a flow matching model with a fixed chunk size.
+ In addition, we also borrow some results from Zeng et al. (2024a), including results of TWIST (Hassid et al., 2024b), SpeechGPT (Zhang et al., 2023), Spectron (Nachmani et al., 2024), and Moshi (Défossez et al., 2024).
+
+ 5 Results and Analysis
+
+ 5.1 Main Results
+
+ Table 1 presents the main results on the speech question answering and speech instruction following benchmarks.
+ Spoken Question Answering For the SpokenQA task, we observe that: (1) For models with similar parameter sizes, LLaMA-Omni2-7B outperforms both GLM-4-Voice and LLaMA-Omni in both S2T and S2S settings. Notably, our model significantly reduces the gap between S2T and S2S performance. For example, on the Web Questions benchmark, GLM-4-Voice drops by 16.3 (32.2→15.9), LLaMA-Omni drops by 9.7 (33.4→23.7), while LLaMA-Omni2-7B only drops by 3.2 (34.5→31.3), demonstrating that our approach largely improves speech generation capabilities. (2) For models with varying parameter sizes, we observe that accuracy increases as the LLM size grows, indicating that LLaMA-Omni 2 effectively leverages the LLM's inherent capabilities. For smaller models, LLaMA-Omni2-1.5B/3B exceeds the accuracy of GLM-4-Voice and LLaMA-Omni in the S2S setting, making them suitable choices for edge devices. For larger models, we observe a significant accuracy improvement with LLaMA-Omni2-14B compared to LLaMA-Omni2-7B, highlighting the potential of our approach for scaling to larger models.
+ Speech Instruction Following For the speech instruction following task, we observe that: (1) LLaMA-Omni2-3B/7B/14B outperforms both GLM-4-Voice and LLaMA-Omni in the S2T and S2S settings, demonstrating the strong instruction-following capabilities of our models. (2) Similar to the results on the SpokenQA benchmarks, we observe that model performance improves as the LLM size increases, with LLaMA-Omni2-14B achieving significantly better performance. (3) The models' ASR-WER is generally low, significantly lower than previous models, proving that our models maintain strong consistency between the text and speech responses. (4) Regarding speech quality, thanks to CosyVoice 2's strong causal flow matching model, our models achieve good UTMOS scores under streaming synthesis, significantly outperforming the baseline models. (5) The latency of LLaMA-Omni 2 is around 600 ms. Although it is slightly higher than LLaMA-Omni, it still meets the requirements for real-time interaction and is significantly lower than that of GLM-4-Voice.
+
+ 9 https://github.com/openai/whisper/blob/main/whisper/normalizers/english.py
+ 10 https://github.com/tarepan/SpeechMOS
581
+
+ 5.2 Ablation Studies
+
+ To understand the impact of different factors on overall performance, we conduct a series of ablation studies on the LLaMA-Omni2-7B model.
+ Gate Fusion Module Table 2 shows the ablation study on the gate fusion module. The gate fusion module allows the model to adaptively fuse LLM hidden states and text embeddings, considering both contextual information and textual content. When the gate fusion module is removed and the two components are simply added together (e_i^hidden + e_i^emb) as input to the text-to-speech language model, we observe a decrease in performance. Further removing the text embedding and only inputting the hidden states (e_i^hidden) results in a further performance decline. This validates the effectiveness of adding text embeddings as input and adaptively fusing them with the gate fusion module.
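+ A minimal sketch of one plausible gate fusion layer is shown below; the exact parameterization is not spelled out in this section, so the single linear gate is an assumption:
+
+ import torch
+ import torch.nn as nn
+
+ class GateFusion(nn.Module):
+     """Adaptively fuse LLM hidden states h with text embeddings e (sketch)."""
+     def __init__(self, dim):
+         super().__init__()
+         self.gate = nn.Linear(2 * dim, dim)
+
+     def forward(self, h, e):  # h, e: (batch, seq, dim)
+         g = torch.sigmoid(self.gate(torch.cat([h, e], dim=-1)))  # per-dimension gate in (0, 1)
+         return g * h + (1.0 - g) * e  # the ablations instead use h + e, or h alone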
605
+
+ Table 3: Ablation study on different TTS pretraining strategies with LLaMA-Omni2-7B.
+
+ TTS Pretraining Our text-to-speech language model is initialized with the Qwen2.5-0.5B model and undergoes streaming TTS pretraining using text-speech pairs from the speech dialogue data in Stage I(b) (R = 3, W = 10). We also explore several other strategies, as shown in Table 3. "Offline TTS" refers to pretraining with the offline TTS task on top of Qwen2.5-0.5B, which shows a slight performance drop compared to the streaming TTS pretraining. "Text Pretrained" refers to directly initializing with Qwen2.5-0.5B (with the extended vocabulary including speech tokens), and we observe a significant performance decline. "Scratch" refers to a randomly initialized model, whose loss fails to converge within a short period. These experiments demonstrate the importance of pretraining for the TTS language model.
+
+ Table 4: Ablation study on the read/write strategy with LLaMA-Omni2-7B. "Offline" means generating speech tokens only after receiving the complete input, and then synthesizing all speech tokens into waveform at once.
+ [Only the R/W settings of Table 4 survived extraction: (R=1, W=5), (R=2, W=10), (R=3, W=10), (R=3, W=15), (R=4, W=15), (R=5, W=20), and Offline; the metric values are not recoverable.]
+
+ Read/Write Strategy The read/write strategy of the TTS language model is a key factor influencing performance, primarily affecting the speech quality and system response latency. As shown in Table 4, we explore different combinations of R and W. First, we observe that when R = 3 and W = 10, the ASR-WER is the lowest, indicating the best alignment between speech and text responses. As for the UTMOS score, we find that it is primarily determined by W, as W represents the chunk size of speech tokens input to the flow matching model, with larger chunk sizes leading to better speech quality. Regarding response latency, it is jointly determined by R and W, as shown in Equation 6. Without any engineering optimizations, LLaMA-Omni2-7B can achieve a latency below 500 ms. We choose R = 3 and W = 10 in our main experiments because it provides a good trade-off across all aspects.
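+ To make the R/W notation concrete, here is a small sketch of the alternating schedule: read R text tokens from the LLM, then write a chunk of W speech tokens for the flow matching model. It is a simplification that ignores end-of-sequence handling, and tts_write_step is a hypothetical callable:
+
+ def read_write_schedule(text_tokens, tts_write_step, R=3, W=10):
+     """Yield a W-token speech chunk after every R text tokens read (sketch)."""
+     read_buffer = []
+     for tok in text_tokens:  # text tokens stream in from the LLM
+         read_buffer.append(tok)
+         if len(read_buffer) == R:
+             yield tts_write_step(read_buffer, n_tokens=W)  # write W speech tokens
+             read_buffer = []
+     if read_buffer:  # flush any remaining text at the end of the response
+         yield tts_write_step(read_buffer, n_tokens=W)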
668
+
669
+ Table 5: Results under different training data sizes with LLaMA-Omni2-7B.
670
+
671
+ 5.3 Effects of the Training Data Sizes
+
+ We explore the impact of different training data sizes on performance. As shown in Table 5, we first observe that, with the same number of training samples, multi-turn dialogue data consistently achieves better results across all benchmarks compared to single-turn dialogue data, highlighting the effectiveness of multi-turn dialogue data for training. Additionally, for different training data sizes, we observe that as the data size increases, the model's performance improves, gradually stabilizing at 200K training samples. This indicates that our 200K multi-turn dialogue data is generally sufficient while ensuring efficient training.
+
+ 6 Related Work
+
+ With the rapid development of LLMs, SpeechLMs have gained widespread attention in recent years (Cui et al., 2024; Ji et al., 2024), aiming to endow LLMs with the ability to understand or generate speech. Generally speaking, SpeechLMs can be divided into two categories: native SpeechLMs and modular SpeechLMs. Native SpeechLMs refer to decoder-only Transformer models capable of directly inputting and outputting speech tokens. Some early works include SpeechGPT (Zhang et al., 2023, 2024a), AudioPaLM (Rubenstein et al., 2023), and TWIST (Hassid et al., 2024a). These models first convert speech into discrete tokens, then extend the vocabulary of pretrained LLMs to include these tokens, and finally train the LLMs using a large amount of speech or speech-text pair data. Spirit-LM (Nguyen et al., 2024) and GLM-4-Voice (Zeng et al., 2025, 2024a) propose training models using speech-text interleaved data to encourage cross-modal knowledge transfer. Moshi (Défossez et al., 2024), OmniFlatten (Zhang et al., 2024b) and LSLM (Ma et al., 2024a) propose models capable of full-duplex conversations. IntrinsicVoice (Zhang et al., 2024c) proposes a GroupFormer architecture to shorten speech length to be closer to that of text. In contrast to native SpeechLMs, modular SpeechLMs add speech-related modules on top of LLMs. Early works achieve speech understanding tasks by combining speech encoders with LLMs, but are unable to perform speech generation (Wu et al., 2023; Wang et al., 2023; Chu et al., 2023; Yu et al., 2024; Ma et al., 2024b; Hono et al., 2024; Chen et al., 2024b; Tang et al., 2024; Chu et al., 2024; Fathullah et al., 2024). To achieve speech generation, LLaMA-Omni (Fang et al., 2025), Freeze-Omni (Wang et al., 2024), and OpenOmni (Luo et al., 2025) add a speech decoder after LLMs. Mini-Omni (Xie and Wu, 2024) and SLAM-Omni (Chen et al., 2024a) enable LLMs to generate speech tokens simultaneously while generating text tokens. The most related work to ours is the concurrent work Minmo (Chen et al., 2025), which also adopts an autoregressive streaming speech decoder similar to CosyVoice 2. In comparison, Minmo is trained on 1.4M hours of data, while we train on only a few thousand hours of data, providing a more efficient training solution. Additionally, we conduct detailed ablation studies on LLM sizes, read-write strategies, and model architecture to offer a more comprehensive understanding of the model.
732
+
733
+ 7 Conclusion
+
+ In this paper, we introduce LLaMA-Omni 2, a series of speech language models ranging from 0.5B to 14B parameters, designed to enable real-time, high-quality speech interaction. LLaMA-Omni 2 achieves streaming speech generation by integrating an autoregressive text-to-speech language model and a causal flow matching model. Experimental results on spoken question answering and speech instruction following tasks show that LLaMA-Omni 2 outperforms previous state-of-the-art speech language models, including LLaMA-Omni and GLM-4-Voice. Additionally, LLaMA-Omni 2 can achieve latency under 600 ms, meeting real-time interaction requirements. We also conduct detailed ablation studies to understand the impact of various factors on overall performance. In the future, we will explore enhancing LLaMA-Omni 2 to generate more human-like speech, incorporating features such as emotion and dialects.
749
+
750
+ Limitations
+
+ One limitation of our model is that currently it cannot generate speech responses with different styles (such as emotion or speech rate) based on the content of the input speech or underlying paralinguistic information, as we have only trained on conventional speech-to-speech dialogue data. However, we believe this functionality can be achieved through a data-driven approach, as our model is end-to-end trained and could acquire this capability after further training with suitable data. We plan to explore this in the future.
+
+ Ethical Considerations
+
+ Since LLaMA-Omni 2 is built on LLMs, it carries some of the same risks as LLMs, such as the potential for factual errors or other hallucination issues in its outputs. We recommend that the model's outputs be checked in practical use to ensure they comply with the required standards.
+
+ References
+
+ Keyu An, Qian Chen, Chong Deng, Zhihao Du, Changfeng Gao, Zhifu Gao, Yue Gu, Ting He, Hangrui Hu, Kai Hu, et al. 2024. Funaudiollm: Voice understanding and generation foundation models for natural interaction between humans and llms. arXiv preprint arXiv:2407.04051.
+ Jonathan Berant, Andrew Chou, Roy Frostig, and Percy Liang. 2013. Semantic parsing on Freebase from question-answer pairs. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, pages 1533–1544, Seattle, Washington, USA. Association for Computational Linguistics.
+ Qian Chen, Yafeng Chen, Yanni Chen, Mengzhe Chen, Yingda Chen, Chong Deng, Zhihao Du, Ruize Gao, Changfeng Gao, Zhifu Gao, et al. 2025. Minmo: A multimodal large language model for seamless voice interaction. arXiv preprint arXiv:2501.06282.
+ Wenxi Chen, Ziyang Ma, Ruiqi Yan, Yuzhe Liang, Xiquan Li, Ruiyang Xu, Zhikang Niu, Yanqiao Zhu, Yifan Yang, Zhanxun Liu, et al. 2024a. Slam-omni: Timbre-controllable voice interaction system with single-stage training. arXiv preprint arXiv:2412.15649.
+ Xi Chen, Songyang Zhang, Qibing Bai, Kai Chen, and Satoshi Nakamura. 2024b. LLaST: Improved end-to-end speech translation system leveraged by large language models. In Findings of the Association for Computational Linguistics ACL 2024, pages 6976–6987, Bangkok, Thailand and virtual meeting. Association for Computational Linguistics.
+ Yunfei Chu, Jin Xu, Xiaohuan Zhou, Qian Yang, Shiliang Zhang, Zhijie Yan, Chang Zhou, and Jingren Zhou. 2023. Qwen-audio: Advancing universal audio understanding via unified large-scale audio-language models. arXiv preprint arXiv:2311.07919.
+ Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, and Jingren Zhou. 2024. Qwen2-audio technical report. arXiv preprint arXiv:2407.10759.
+ Leigh Clark, Philip Doyle, Diego Garaialde, Emer Gilmartin, Stephan Schlögl, Jens Edlund, Matthew Aylett, João Cabral, Cosmin Munteanu, Justin Edwards, and Benjamin R Cowan. 2019. The state of speech in HCI: Trends, themes and challenges. Interacting with Computers, 31(4):349–371.
+ Wenqian Cui, Dianzhi Yu, Xiaoqi Jiao, Ziqiao Meng, Guangyan Zhang, Qichao Wang, Yiwen Guo, and Irwin King. 2024. Recent advances in speech language models: A survey. arXiv preprint arXiv:2410.03751.
+ Alexandre Défossez, Laurent Mazaré, Manu Orsini, Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard Grave, and Neil Zeghidour. 2024. Moshi: a speech-text foundation model for real-time dialogue. Technical report.
+ Ning Ding, Yulin Chen, Bokai Xu, Yujia Qin, Zhi Zheng, Shengding Hu, Zhiyuan Liu, Maosong Sun, and Bowen Zhou. 2023. Enhancing chat language models by scaling high-quality instructional conversations. arXiv preprint arXiv:2305.14233.
+ Zhihao Du, Yuxuan Wang, Qian Chen, Xian Shi, Xiang Lv, Tianyu Zhao, Zhifu Gao, Yexin Yang, Changfeng Gao, Hui Wang, et al. 2024. Cosyvoice 2: Scalable streaming speech synthesis with large language models. arXiv preprint arXiv:2412.10117.
+ Abhimanyu Dubey, Abhinav Jauhri, Abhinav Pandey, Abhishek Kadian, Ahmad Al-Dahle, Aiesha Letman, Akhil Mathur, Alan Schelten, Amy Yang, Angela Fan, et al. 2024. The llama 3 herd of models. arXiv preprint arXiv:2407.21783.
+ Qingkai Fang, Shoutao Guo, Yan Zhou, Zhengrui Ma, Shaolei Zhang, and Yang Feng. 2025. LLaMA-omni: Seamless speech interaction with large language models. In The Thirteenth International Conference on Learning Representations.
+ Yassir Fathullah, Chunyang Wu, Egor Lakomkin, Ke Li, Junteng Jia, Yuan Shangguan, Jay Mahadeokar, Ozlem Kalinli, Christian Fuegen, and Mike Seltzer. 2024. Audiochatllama: Towards general-purpose speech abilities for llms. In Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers), pages 5522–5532.
+ Alex Graves, Santiago Fernández, Faustino Gomez, and Jürgen Schmidhuber. 2006. Connectionist temporal classification: Labelling unsegmented sequence data with recurrent neural networks. In Proceedings of the 23rd International Conference on Machine Learning, ICML '06, page 369–376, New York, NY, USA. Association for Computing Machinery.
+ Michael Hassid, Tal Remez, Tu Anh Nguyen, Itai Gat, Alexis Conneau, Felix Kreuk, Jade Copet, Alexandre Defossez, Gabriel Synnaeve, Emmanuel Dupoux, et al. 2024a. Textually pretrained speech language models. Advances in Neural Information Processing Systems, 36.
+ Michael Hassid, Tal Remez, Tu Anh Nguyen, Itai Gat, Alexis Conneau, Felix Kreuk, Jade Copet, Alexandre Defossez, Gabriel Synnaeve, Emmanuel Dupoux, et al. 2024b. Textually pretrained speech language models. Advances in Neural Information Processing Systems, 36.
+ Yukiya Hono, Koh Mitsuda, Tianyu Zhao, Kentaro Mitsui, Toshiaki Wakatsuki, and Kei Sawada. 2024. Integrating pre-trained speech and language models for end-to-end speech recognition. In Findings of the Association for Computational Linguistics ACL 2024, pages 13289–13305, Bangkok, Thailand and virtual meeting. Association for Computational Linguistics.
+ Shengpeng Ji, Yifu Chen, Minghui Fang, Jialong Zuo, Jingyu Lu, Hanting Wang, Ziyue Jiang, Long Zhou, Shujie Liu, Xize Cheng, et al. 2024. Wavchat: A survey of spoken dialogue models. arXiv preprint arXiv:2411.13577.
+ Jungil Kong, Jaehyeon Kim, and Jaekyoung Bae. 2020. Hifi-gan: Generative adversarial networks for efficient and high fidelity speech synthesis. In Advances in Neural Information Processing Systems, volume 33, pages 17022–17033. Curran Associates, Inc.
+ Xuechen Li, Tianyi Zhang, Yann Dubois, Rohan Taori, Ishaan Gulrajani, Carlos Guestrin, Percy Liang, and Tatsunori B. Hashimoto. 2023. Alpacaeval: An automatic evaluator of instruction-following models. https://github.com/tatsu-lab/alpaca_eval.
+ Shijia Liao, Yuxuan Wang, Tianyu Li, Yifan Cheng, Ruoyi Zhang, Rongzhi Zhou, and Yijin Xing. 2024. Fish-speech: Leveraging large language models for advanced multilingual text-to-speech synthesis. Preprint, arXiv:2411.01156.
+ Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, and Matthew Le. 2023. Flow matching for generative modeling. In The Eleventh International Conference on Learning Representations.
+ Run Luo, Ting-En Lin, Haonan Zhang, Yuchuan Wu, Xiong Liu, Min Yang, Yongbin Li, Longze Chen, Jiaming Li, Lei Zhang, et al. 2025. Openomni: Large language models pivot zero-shot omnimodal alignment across language with real-time self-aware emotional speech synthesis. arXiv preprint arXiv:2501.04561.
+ Ziyang Ma, Yakun Song, Chenpeng Du, Jian Cong, Zhuo Chen, Yuping Wang, Yuxuan Wang, and Xie Chen. 2024a. Language model can listen while speaking. arXiv preprint arXiv:2408.02622.
+ Ziyang Ma, Guanrou Yang, Yifan Yang, Zhifu Gao, Jiaming Wang, Zhihao Du, Fan Yu, Qian Chen, Siqi Zheng, Shiliang Zhang, et al. 2024b. An embarrassingly simple approach for llm with strong asr capacity. arXiv preprint arXiv:2402.08846.
+ Fabian Mentzer, David Minnen, Eirikur Agustsson, and Michael Tschannen. 2024. Finite scalar quantization: VQ-VAE made simple. In The Twelfth International Conference on Learning Representations.
+ Eliya Nachmani, Alon Levkovitch, Roy Hirsch, Julian Salazar, Chulayuth Asawaroengchai, Soroosh Mariooryad, Ehud Rivlin, RJ Skerry-Ryan, and Michelle Tadmor Ramanovich. 2024. Spoken question answering and speech continuation using spectrogram-powered LLM. In The Twelfth International Conference on Learning Representations.
+ Tu Anh Nguyen, Benjamin Muller, Bokai Yu, Marta R Costa-Jussa, Maha Elbayad, Sravya Popuri, Paul-Ambroise Duquenne, Robin Algayres, Ruslan Mavlyutov, Itai Gat, et al. 2024. Spirit-lm: Interleaved spoken and written language model. arXiv preprint arXiv:2402.05755.
+ Open-Moss. 2025. Speechgpt 2.0-preview. https://github.com/OpenMOSS/SpeechGPT-2.0-preview.
+ OpenAI. 2022. Introducing chatgpt.
+ OpenAI. 2024. Hello gpt-4o.
+ Qwen Team. 2024. Qwen2.5: A party of foundation models.
+ Alec Radford. 2018. Improving language understanding by generative pre-training.
+ Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, and Ilya Sutskever. 2023. Robust speech recognition via large-scale weak supervision. In International Conference on Machine Learning, pages 28492–28518. PMLR.
+ Paul K Rubenstein, Chulayuth Asawaroengchai, Duc Dung Nguyen, Ankur Bapna, Zalán Borsos, Félix de Chaumont Quitry, Peter Chen, Dalia El Badawy, Wei Han, Eugene Kharitonov, et al. 2023. Audiopalm: A large language model that can speak and listen. arXiv preprint arXiv:2306.12925.
+ Takaaki Saeki, Detai Xin, Wataru Nakata, Tomoki Koriyama, Shinnosuke Takamichi, and Hiroshi Saruwatari. 2022. Utmos: Utokyo-sarulab system for voicemos challenge 2022. In Interspeech 2022, pages 4521–4525.
+ Changli Tang, Wenyi Yu, Guangzhi Sun, Xianzhao Chen, Tian Tan, Wei Li, Lu Lu, Zejun Ma, and Chao Zhang. 2024. SALMONN: Towards generic hearing abilities for large language models. In The Twelfth International Conference on Learning Representations.
+ Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li, Carlos Guestrin, Percy Liang, and Tatsunori B. Hashimoto. 2023. Stanford alpaca: An instruction-following llama model. https://github.com/tatsu-lab/stanford_alpaca.
+ A Vaswani. 2017. Attention is all you need. Advances in Neural Information Processing Systems.
+ Chen Wang, Minpeng Liao, Zhongqiang Huang, Jinliang Lu, Junhong Wu, Yuchen Liu, Chengqing Zong, and Jiajun Zhang. 2023. Blsp: Bootstrapping language-speech pre-training via behavior alignment of continuation writing. arXiv preprint arXiv:2309.00916.
+ Xiong Wang, Yangze Li, Chaoyou Fu, Yunhang Shen, Lei Xie, Ke Li, Xing Sun, and Long Ma. 2024. Freeze-omni: A smart and low latency speech-to-speech dialogue model with frozen llm. arXiv preprint arXiv:2411.00774.
+ Jian Wu, Yashesh Gaur, Zhuo Chen, Long Zhou, Yimeng Zhu, Tianrui Wang, Jinyu Li, Shujie Liu, Bo Ren, Linquan Liu, et al. 2023. On decoder-only architecture for speech-to-text and large language model integration. In 2023 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU), pages 1–8. IEEE.
+ Zhifei Xie and Changqiao Wu. 2024. Mini-omni: Language models can hear, talk while thinking in streaming. arXiv preprint arXiv:2408.16725.
+ Wenyi Yu, Changli Tang, Guangzhi Sun, Xianzhao Chen, Tian Tan, Wei Li, Lu Lu, Zejun Ma, and Chao Zhang. 2024. Connecting speech encoder and large language model for asr. In ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pages 12637–12641. IEEE.
+ Aohan Zeng, Zhengxiao Du, Mingdao Liu, Kedong Wang, Shengmin Jiang, Lei Zhao, Yuxiao Dong, and Jie Tang. 2024a. Glm-4-voice: Towards intelligent and human-like end-to-end spoken chatbot. Preprint, arXiv:2412.02612.
+ Aohan Zeng, Zhengxiao Du, Mingdao Liu, Lei Zhang, Shengmin Jiang, Yuxiao Dong, and Jie Tang. 2024b. Scaling speech-text pre-training with synthetic interleaved data. Preprint, arXiv:2411.17607.
+ Aohan Zeng, Zhengxiao Du, Mingdao Liu, Lei Zhang, Shengmin Jiang, Yuxiao Dong, and Jie Tang. 2025. Scaling speech-text pre-training with synthetic interleaved data. In The Thirteenth International Conference on Learning Representations.
+ Dong Zhang, Shimin Li, Xin Zhang, Jun Zhan, Pengyu Wang, Yaqian Zhou, and Xipeng Qiu. 2023. SpeechGPT: Empowering large language models with intrinsic cross-modal conversational abilities. In Findings of the Association for Computational Linguistics: EMNLP 2023, pages 15757–15773, Singapore. Association for Computational Linguistics.
+ Dong Zhang, Xin Zhang, Jun Zhan, Shimin Li, Yaqian Zhou, and Xipeng Qiu. 2024a. Speechgpt-gen: Scaling chain-of-information speech generation. arXiv preprint arXiv:2401.13527.
+ Qinglin Zhang, Luyao Cheng, Chong Deng, Qian Chen, Wen Wang, Siqi Zheng, Jiaqing Liu, Hai Yu, Chaohong Tan, Zhihao Du, et al. 2024b. Omniflatten: An end-to-end gpt model for seamless voice conversation. arXiv preprint arXiv:2410.17799.
+ Xin Zhang, Xiang Lyu, Zhihao Du, Qian Chen, Dong Zhang, Hangrui Hu, Chaohong Tan, Tianyu Zhao, Yuxuan Wang, Bin Zhang, et al. 2024c. Intrinsicvoice: Empowering llms with intrinsic real-time voice interaction abilities. arXiv preprint arXiv:2410.08035.
1034
+
1035
+ A Prompt
1038
+
1039
+ Prompt for ChatGPT Scoring (Model: GPT-4o)
1040
+ I need your help to evaluate the performance of several
1041
+ models in a speech interaction scenario. The models receive the user’s speech input and respond with speech
1042
+ output. For evaluation purposes, both the user’s speech
1043
+ input and the model’s speech response have been transcribed into text using Automatic Speech Recognition
1044
+ (ASR). Your task is to rate the model’s responses based
1045
+ on the provided user input transcription [Instruction] and
1046
+ the model’s output transcription [Response]. Please consider factors such as helpfulness, relevance, fluency, and
1047
+ suitability for speech interaction in your evaluation, and
1048
+ provide a single score on a scale from 1 to 5.
1049
+ Below are the transcription of user’s instruction and models’ response:
1050
+ ### [Instruction]: {instruction}
1051
+ ### [Response]: {response}
1052
+ After evaluating, please output the scores in JSON format:
1053
+ {score: ...}. You don’t need to provide any explanations.
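+ A minimal sketch of how this prompt might be applied with the OpenAI Python SDK follows; the message framing, the placeholder substitution, and the JSON parsing are illustrative assumptions:
+
+ import json
+ from openai import OpenAI
+
+ client = OpenAI()  # reads OPENAI_API_KEY from the environment
+
+ def chatgpt_score(prompt_template, instruction, response):
+     content = (prompt_template
+                .replace("{instruction}", instruction)
+                .replace("{response}", response))
+     reply = client.chat.completions.create(
+         model="gpt-4o",
+         messages=[{"role": "user", "content": content}],
+     )
+     # Assumes the model follows the instruction and returns JSON like {"score": 4}.
+     return float(json.loads(reply.choices[0].message.content)["score"])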
1054
+
1055
+ B Detailed Latency
1058
+
1059
+ We list the detailed latency at different stages of
1060
+ the model in Table 6. “LLM” refers to the latency
1061
+ for generating the first R text tokens, “TTS” refers
1062
+ to the latency for generating the first W speech
1063
+ tokens, and “FM+Voc” refers to the latency for
1064
+ generating the first speech chunk using the flow
1065
+ matching model and vocoder.
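+ As a sketch of this decomposition, the first-chunk latency is the sum of the three stages; the three stage functions below are hypothetical placeholders, not APIs from this codebase:
+
+ import time
+
+ def first_chunk_latency(run_llm_first_r_tokens, run_tts_first_w_tokens, run_fm_and_vocoder):
+     """Time each stage of producing the first speech chunk (sketch)."""
+     stages = {}
+     t = time.perf_counter()
+     text = run_llm_first_r_tokens()
+     stages["LLM"] = time.perf_counter() - t
+     t = time.perf_counter()
+     speech_tokens = run_tts_first_w_tokens(text)
+     stages["TTS"] = time.perf_counter() - t
+     t = time.perf_counter()
+     chunk = run_fm_and_vocoder(speech_tokens)
+     stages["FM+Voc"] = time.perf_counter() - t
+     stages["Total"] = stages["LLM"] + stages["TTS"] + stages["FM+Voc"]
+     return stages, chunk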
CLAUDE.md ADDED
@@ -0,0 +1,215 @@
1
+
2
+ <!-- BACKLOG.MD GUIDELINES START -->
3
+ # Instructions for the usage of Backlog.md CLI Tool
4
+
5
+ ## 1. Source of Truth
6
+
7
+ - Tasks live under **`backlog/tasks/`** (drafts under **`backlog/drafts/`**).
8
+ - Every implementation decision starts with reading the corresponding Markdown task file.
9
+ - Project documentation is in **`backlog/docs/`**.
10
+ - Project decisions are in **`backlog/decisions/`**.
11
+
12
+ ## 2. Defining Tasks
13
+
14
+ ### Understand the Scope and the purpose
15
+
16
+ Ask questions to the user if something is not clear or ambiguous.
17
+ Break down the task into smaller, manageable parts if it is too large or complex.
18
+
19
+ ### **Title (one liner)**
20
+
21
+ Use a clear brief title that summarizes the task.
22
+
23
+ ### **Description**: (The **"why"**)
24
+
25
+ Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
26
+ should explain the purpose and context of the task. Code snippets should be avoided.
27
+
28
+ ### **Acceptance Criteria**: (The **"what"**)
29
+
30
+ List specific, measurable outcomes that define what it means to reach the goal from the description. Use checkboxes (
31
+ `- [ ]`) for tracking.
32
+ When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
33
+ than step-by-step implementation details.
34
+ Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
35
+ They should be testable and confirm that the core purpose of the task is achieved.
36
+ **Key Principles for Good ACs:**
37
+
38
+ - **Outcome-Oriented:** Focus on the result, not the method.
39
+ - **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
40
+ - **Clear and Concise:** Unambiguous language.
41
+ - **Complete:** Collectively, ACs should cover the scope of the task.
42
+ - **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
43
+
44
+ - *Good Example:* "- [ ] User can successfully log in with valid credentials."
45
+ - *Good Example:* "- [ ] System processes 1000 requests per second without errors."
46
+ - *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
47
+
48
+ ### Task file
49
+
50
+ Once a task is created it will be stored in `backlog/tasks/` directory as a Markdown file with the format
51
+ `task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
52
+
53
+ ### Task Breakdown Strategy
54
+
55
+ When breaking down features:
56
+
57
+ 1. Identify the foundational components first
58
+ 2. Create tasks in dependency order (foundations before features)
59
+ 3. Ensure each task delivers value independently
60
+ 4. Avoid creating tasks that block each other
61
+
62
+ ### Additional task requirements
63
+
64
+ - Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
65
+ Each task should represent a single unit of work that can be completed in a single PR.
66
+
67
+ - **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
68
+ previous
69
+ tasks (id < current task id).
70
+
71
+ - When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
72
+ Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
73
+ schema".
74
+ Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
75
+ schema", task 3: "Add API endpoint for user data".
76
+
77
+ ## 3. Recommended Task Anatomy
78
+
79
+ ```markdown
80
+ # task‑42 - Add GraphQL resolver
81
+
82
+ ## Description (the why)
83
+
84
+ Short, imperative explanation of the goal of the task and why it is needed.
85
+
86
+ ## Acceptance Criteria (the what)
87
+
88
+ - [ ] Resolver returns correct data for happy path
89
+ - [ ] Error response matches REST
90
+ - [ ] P95 latency ≤ 50 ms under 100 RPS
91
+
92
+ ## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
93
+
94
+ 1. Research existing GraphQL resolver patterns
95
+ 2. Implement basic resolver with error handling
96
+ 3. Add performance monitoring
97
+ 4. Write unit and integration tests
98
+ 5. Benchmark performance under load
99
+
100
+ ## Implementation Notes (imagine this is the PR description) (only added after finishing the code implementation of a task)
101
+
102
+ - Approach taken
103
+ - Features implemented or modified
104
+ - Technical decisions and trade-offs
105
+ - Modified or added files
106
+ ```
107
+
108
+ ## 4. Implementing Tasks
109
+
110
+ Mandatory sections for every task:
111
+
112
+ - **Implementation Plan**: (The **"how"**) Outline the steps to achieve the task. Because the implementation details may
113
+ change after the task is created, **the implementation plan must be added only after putting the task in progress**
114
+ and before starting working on the task.
115
+ - **Implementation Notes**: Start with a brief summary of what has been implemented. Document your approach, decisions, challenges, and any deviations from the plan. This
116
+ section is added after you are done working on the task. It should summarize what you did and why you did it. Keep it
117
+ concise but informative. Imagine this is the PR description. Make it brief, explain the core changes and assume that
118
+ others will read the code to understand the details.
119
+
120
+ **IMPORTANT**: Do not implement anything else that deviates from the **Acceptance Criteria**. If you need to
121
+ implement something that is not in the AC, update the AC first and then implement it or create a new task for it.
122
+
123
+ ## 5. Typical Workflow
124
+
125
+ ```bash
126
+ # 1 Identify work
127
+ backlog task list -s "To Do" --plain
128
+
129
+ # 2 Read details & documentation
130
+ backlog task 42 --plain
131
+ # Read also all documentation files in `backlog/docs/` directory.
132
+ # Read also all decision files in `backlog/decisions/` directory.
133
+
134
+ # 3 Start work: assign yourself & move column
135
+ backlog task edit 42 -a @{yourself} -s "In Progress"
136
+
137
+ # 4 Add implementation plan before starting
138
+ backlog task edit 42 --plan "1. Analyze current implementation\n2. Identify bottlenecks\n3. Refactor in phases"
139
+
140
+ # 5 Break work down if needed by creating subtasks or additional tasks
141
+ backlog task create "Refactor DB layer" -p 42 -a @{yourself} -d "Description" --ac "Tests pass,Performance improved"
142
+
143
+ # 6 Complete and mark Done
144
+ backlog task edit 42 -s Done --notes "Implemented GraphQL resolver with error handling and performance monitoring"
145
+ ```
146
+
147
+ ## 6. Final Steps Before Marking a Task as Done
148
+
149
+ Always ensure you have:
150
+
151
+ 1. ✅ Marked all acceptance criteria as completed (change `- [ ]` to `- [x]`)
152
+ 2. ✅ Added an `## Implementation Notes` section documenting your approach
153
+ 3. ✅ Run all tests and linting checks
154
+ 4. ✅ Updated relevant documentation
155
+
156
+ ## 7. Definition of Done (DoD)
157
+
158
+ A task is **Done** only when **ALL** of the following are complete:
159
+
160
+ 1. **Acceptance criteria** checklist in the task file is fully checked (all `- [ ]` changed to `- [x]`).
161
+ 2. **Implementation plan** was followed or deviations were documented in Implementation Notes.
162
+ 3. **Automated tests** (unit + integration) cover new logic.
163
+ 4. **Static analysis**: linter & formatter succeed.
164
+ 5. **Documentation**:
165
+ - All relevant docs updated (any relevant README file, backlog/docs, backlog/decisions, etc.).
166
+ - Task file **MUST** have an `## Implementation Notes` section added summarising:
167
+ - Approach taken
168
+ - Features implemented or modified
169
+ - Technical decisions and trade-offs
170
+ - Modified or added files
171
+ 6. **Review**: self review code.
172
+ 7. **Task hygiene**: status set to **Done** via CLI (`backlog task edit <id> -s Done`).
173
+ 8. **No regressions**: performance, security and licence checks green.
174
+
175
+ ⚠️ **IMPORTANT**: Never mark a task as Done without completing ALL items above.
176
+
177
+ ## 8. Handy CLI Commands
178
+
179
+ | Action | Example |
180
+ |-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
181
+ | Create task | `backlog task create "Add OAuth System"` |
182
+ | Create with description | `backlog task create "Feature" -d "Add authentication system"` |
183
+ | Create with assignee | `backlog task create "Feature" -a @sara` |
184
+ | Create with status | `backlog task create "Feature" -s "In Progress"` |
185
+ | Create with labels | `backlog task create "Feature" -l auth,backend` |
186
+ | Create with priority | `backlog task create "Feature" --priority high` |
187
+ | Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
188
+ | Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
189
+ | Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
190
+ | Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
191
+ | Create sub task | `backlog task create -p 14 "Add Login with Google"` |
192
+ | Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
193
+ | List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
194
+ | List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
195
+ | View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
196
+ | View (AI mode) | `backlog task 7 --plain` |
197
+ | Edit | `backlog task edit 7 -a @sara -l auth,backend` |
198
+ | Add plan | `backlog task edit 7 --plan "Implementation approach"` |
199
+ | Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
200
+ | Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
201
+ | Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
202
+ | Archive | `backlog task archive 7` |
203
+ | Create draft | `backlog task create "Feature" --draft` |
204
+ | Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
205
+ | Demote to draft | `backlog task demote <id>` |
206
+
207
+ Full help: `backlog --help`
208
+
209
+ ## 9. Tips for AI Agents
210
+
211
+ - **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
212
+ interactive UI.
213
+ - When users mention creating a task, they mean creating it with the Backlog.md CLI tool.
214
+
215
+ <!-- BACKLOG.MD GUIDELINES END -->
COSYVOICE2_CHANGES.md ADDED
@@ -0,0 +1,87 @@
1
+ # CosyVoice2 Model Changes Documentation
2
+
3
+ ## Overview
4
+ This document captures the modifications made to the CosyVoice2 model integration for the LLaMA-Omni2 voice assistant system.
5
+
6
+ ## Key Changes
7
+
8
+ ### 1. Configuration Files
9
+ - **cosyvoice.yaml**: Primary configuration file used by the model
10
+ - **cosyvoice2.yaml**: Original CosyVoice2 configuration
11
+ - **cosyvoice_fixed.yaml**: Configuration with `mix_ratio` parameter removed to fix compatibility issues
12
+
13
+ ### 2. Model Files Structure
14
+ ```
15
+ models/cosyvoice2/
16
+ ├── CosyVoice-BlankEN/ # English tokenizer model
17
+ ├── campplus.onnx # Speaker embedding model
18
+ ├── flow.decoder.estimator.fp32.onnx # Flow decoder
19
+ ├── flow.pt # Flow model weights
20
+ ├── hift.pt # HiFi-GAN vocoder weights
21
+ ├── llm.pt # Language model weights
22
+ ├── speech_tokenizer_v1.onnx # Speech tokenizer v1
23
+ └── speech_tokenizer_v2.onnx # Speech tokenizer v2 (new addition)
24
+ ```
25
+
26
+ ### 3. Code Modifications
27
+
28
+ #### cosyvoice/flow/flow.py
29
+ - Modified to handle CosyVoice2 model architecture
30
+ - Updated MaskedDiffWithXvec class for compatibility
31
+ - Adjusted decoder configuration parameters
32
+
33
+ #### llama_omni2/serve/flow_inference.py
34
+ - Updated SpeechDecoder class to properly load CosyVoice2 models
35
+ - Changed configuration loading to use 'cosyvoice.yaml' instead of fallback logic
36
+ - Added support for speech_tokenizer_v2.onnx
37
+
38
+ ### 4. Integration Points
39
+ - **Model Path**: `models/cosyvoice2/` or `models/cosy2_decoder/`
40
+ - **Frontend**: CosyVoiceFrontEnd handles tokenization and feature extraction
41
+ - **Vocoder**: The CosyVoice2 model serves as the vocoder for gradio_web_server.py, selected via the `--vocoder-dir` flag
42
+
43
+ ## Setup Requirements
44
+
45
+ ### Model Download
46
+ ```bash
47
+ # Download CosyVoice2 model from HuggingFace
48
+ python -c "
49
+ from huggingface_hub import snapshot_download
50
+ snapshot_download(
51
+ repo_id='FunAudioLLM/CosyVoice2-0.5B',
52
+ local_dir='models/cosyvoice2',
53
+ local_dir_use_symlinks=False
54
+ )"
55
+ ```
56
+
57
+ ### Configuration Fix
58
+ The original cosyvoice2.yaml may contain a `mix_ratio` parameter that causes issues. This is fixed by:
59
+ 1. Copying cosyvoice2.yaml to cosyvoice.yaml
60
+ 2. Removing the mix_ratio parameter
61
+
62
+ ```bash
63
+ cp models/cosyvoice2/cosyvoice2.yaml models/cosyvoice2/cosyvoice.yaml
64
+ grep -v "mix_ratio" models/cosyvoice2/cosyvoice.yaml > models/cosyvoice2/cosyvoice_fixed.yaml
65
+ mv models/cosyvoice2/cosyvoice_fixed.yaml models/cosyvoice2/cosyvoice.yaml
66
+ ```
67
+
68
+ ## Usage in LLaMA-Omni2
69
+
70
+ Start the Gradio server with CosyVoice2 as vocoder:
71
+ ```bash
72
+ python -m llama_omni2.serve.gradio_web_server \
73
+ --controller http://localhost:10000 \
74
+ --port 8000 \
75
+ --vocoder-dir models/cosyvoice2
76
+ ```
77
+
78
+ ## Known Issues and Solutions
79
+
80
+ 1. **mix_ratio parameter error**: Remove from configuration file
81
+ 2. **Missing cosyvoice.yaml**: Copy from cosyvoice2.yaml
82
+ 3. **Tokenizer compatibility**: Ensure both v1 and v2 tokenizers are present
83
+
84
+ ## Performance Notes
85
+ - CosyVoice2-0.5B is optimized for faster inference
86
+ - Supports both Chinese and English text-to-speech
87
+ - Compatible with streaming generation for real-time applications
GEMINI.md ADDED
@@ -0,0 +1,215 @@
1
+
2
+ <!-- BACKLOG.MD GUIDELINES START -->
3
+ # Instructions for the usage of Backlog.md CLI Tool
4
+
5
+ ## 1. Source of Truth
6
+
7
+ - Tasks live under **`backlog/tasks/`** (drafts under **`backlog/drafts/`**).
8
+ - Every implementation decision starts with reading the corresponding Markdown task file.
9
+ - Project documentation is in **`backlog/docs/`**.
10
+ - Project decisions are in **`backlog/decisions/`**.
11
+
12
+ ## 2. Defining Tasks
13
+
14
+ ### Understand the Scope and the purpose
15
+
16
+ Ask questions to the user if something is not clear or ambiguous.
17
+ Break down the task into smaller, manageable parts if it is too large or complex.
18
+
19
+ ### **Title (one liner)**
20
+
21
+ Use a clear brief title that summarizes the task.
22
+
23
+ ### **Description**: (The **"why"**)
24
+
25
+ Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
26
+ should explain the purpose and context of the task. Code snippets should be avoided.
27
+
28
+ ### **Acceptance Criteria**: (The **"what"**)
29
+
30
+ List specific, measurable outcomes that define what it means to reach the goal from the description. Use checkboxes (
31
+ `- [ ]`) for tracking.
32
+ When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
33
+ than step-by-step implementation details.
34
+ Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
35
+ They should be testable and confirm that the core purpose of the task is achieved.
36
+ **Key Principles for Good ACs:**
37
+
38
+ - **Outcome-Oriented:** Focus on the result, not the method.
39
+ - **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
40
+ - **Clear and Concise:** Unambiguous language.
41
+ - **Complete:** Collectively, ACs should cover the scope of the task.
42
+ - **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
43
+
44
+ - *Good Example:* "- [ ] User can successfully log in with valid credentials."
45
+ - *Good Example:* "- [ ] System processes 1000 requests per second without errors."
46
+ - *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
47
+
48
+ ### Task file
49
+
50
+ Once a task is created it will be stored in `backlog/tasks/` directory as a Markdown file with the format
51
+ `task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
52
+
53
+ ### Task Breakdown Strategy
54
+
55
+ When breaking down features:
56
+
57
+ 1. Identify the foundational components first
58
+ 2. Create tasks in dependency order (foundations before features)
59
+ 3. Ensure each task delivers value independently
60
+ 4. Avoid creating tasks that block each other
61
+
62
+ ### Additional task requirements
63
+
64
+ - Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
65
+ Each task should represent a single unit of work that can be completed in a single PR.
66
+
67
+ - **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
68
+ previous
69
+ tasks (id < current task id).
70
+
71
+ - When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
72
+ Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
73
+ schema".
74
+ Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
75
+ schema", task 3: "Add API endpoint for user data".
76
+
77
+ ## 3. Recommended Task Anatomy
78
+
79
+ ```markdown
80
+ # task‑42 - Add GraphQL resolver
81
+
82
+ ## Description (the why)
83
+
84
+ Short, imperative explanation of the goal of the task and why it is needed.
85
+
86
+ ## Acceptance Criteria (the what)
87
+
88
+ - [ ] Resolver returns correct data for happy path
89
+ - [ ] Error response matches REST
90
+ - [ ] P95 latency ≤ 50 ms under 100 RPS
91
+
92
+ ## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
93
+
94
+ 1. Research existing GraphQL resolver patterns
95
+ 2. Implement basic resolver with error handling
96
+ 3. Add performance monitoring
97
+ 4. Write unit and integration tests
98
+ 5. Benchmark performance under load
99
+
100
+ ## Implementation Notes (imagine this is the PR description) (only added after finishing the code implementation of a task)
101
+
102
+ - Approach taken
103
+ - Features implemented or modified
104
+ - Technical decisions and trade-offs
105
+ - Modified or added files
106
+ ```
107
+
108
+ ## 4. Implementing Tasks
109
+
110
+ Mandatory sections for every task:
111
+
112
+ - **Implementation Plan**: (The **"how"**) Outline the steps to achieve the task. Because the implementation details may
113
+ change after the task is created, **the implementation plan must be added only after putting the task in progress**
114
+ and before starting working on the task.
115
+ - **Implementation Notes**: Start with a brief summary of what has been implemented. Document your approach, decisions, challenges, and any deviations from the plan. This
116
+ section is added after you are done working on the task. It should summarize what you did and why you did it. Keep it
117
+ concise but informative. Imagine this is the PR description. Make it brief, explain the core changes and assume that
118
+ others will read the code to understand the details.
119
+
120
+ **IMPORTANT**: Do not implement anything else that deviates from the **Acceptance Criteria**. If you need to
121
+ implement something that is not in the AC, update the AC first and then implement it or create a new task for it.
122
+
123
+ ## 5. Typical Workflow
124
+
125
+ ```bash
126
+ # 1 Identify work
127
+ backlog task list -s "To Do" --plain
128
+
129
+ # 2 Read details & documentation
130
+ backlog task 42 --plain
131
+ # Read also all documentation files in `backlog/docs/` directory.
132
+ # Read also all decision files in `backlog/decisions/` directory.
133
+
134
+ # 3 Start work: assign yourself & move column
135
+ backlog task edit 42 -a @{yourself} -s "In Progress"
136
+
137
+ # 4 Add implementation plan before starting
138
+ backlog task edit 42 --plan "1. Analyze current implementation\n2. Identify bottlenecks\n3. Refactor in phases"
139
+
140
+ # 5 Break work down if needed by creating subtasks or additional tasks
141
+ backlog task create "Refactor DB layer" -p 42 -a @{yourself} -d "Description" --ac "Tests pass,Performance improved"
142
+
143
+ # 6 Complete and mark Done
144
+ backlog task edit 42 -s Done --notes "Implemented GraphQL resolver with error handling and performance monitoring"
145
+ ```
146
+
147
+ ## 6. Final Steps Before Marking a Task as Done
148
+
149
+ Always ensure you have:
150
+
151
+ 1. ✅ Marked all acceptance criteria as completed (change `- [ ]` to `- [x]`)
152
+ 2. ✅ Added an `## Implementation Notes` section documenting your approach
153
+ 3. ✅ Run all tests and linting checks
154
+ 4. ✅ Updated relevant documentation
155
+
156
+ ## 7. Definition of Done (DoD)
157
+
158
+ A task is **Done** only when **ALL** of the following are complete:
159
+
160
+ 1. **Acceptance criteria** checklist in the task file is fully checked (all `- [ ]` changed to `- [x]`).
161
+ 2. **Implementation plan** was followed or deviations were documented in Implementation Notes.
162
+ 3. **Automated tests** (unit + integration) cover new logic.
163
+ 4. **Static analysis**: linter & formatter succeed.
164
+ 5. **Documentation**:
165
+ - All relevant docs updated (any relevant README file, backlog/docs, backlog/decisions, etc.).
166
+ - Task file **MUST** have an `## Implementation Notes` section added summarising:
167
+ - Approach taken
168
+ - Features implemented or modified
169
+ - Technical decisions and trade-offs
170
+ - Modified or added files
171
+ 6. **Review**: self review code.
172
+ 7. **Task hygiene**: status set to **Done** via CLI (`backlog task edit <id> -s Done`).
173
+ 8. **No regressions**: performance, security and licence checks green.
174
+
175
+ ⚠️ **IMPORTANT**: Never mark a task as Done without completing ALL items above.
176
+
177
+ ## 8. Handy CLI Commands
178
+
179
+ | Action | Example |
180
+ |-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
181
+ | Create task | `backlog task create "Add OAuth System"` |
182
+ | Create with description | `backlog task create "Feature" -d "Add authentication system"` |
183
+ | Create with assignee | `backlog task create "Feature" -a @sara` |
184
+ | Create with status | `backlog task create "Feature" -s "In Progress"` |
185
+ | Create with labels | `backlog task create "Feature" -l auth,backend` |
186
+ | Create with priority | `backlog task create "Feature" --priority high` |
187
+ | Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
188
+ | Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
189
+ | Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
190
+ | Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
191
+ | Create sub task | `backlog task create -p 14 "Add Login with Google"` |
192
+ | Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
193
+ | List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
194
+ | List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
195
+ | View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
196
+ | View (AI mode) | `backlog task 7 --plain` |
197
+ | Edit | `backlog task edit 7 -a @sara -l auth,backend` |
198
+ | Add plan | `backlog task edit 7 --plan "Implementation approach"` |
199
+ | Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
200
+ | Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
201
+ | Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
202
+ | Archive | `backlog task archive 7` |
203
+ | Create draft | `backlog task create "Feature" --draft` |
204
+ | Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
205
+ | Demote to draft | `backlog task demote <id>` |
206
+
207
+ Full help: `backlog --help`
208
+
209
+ ## 9. Tips for AI Agents
210
+
211
+ - **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
212
+ interactive UI.
213
+ - When users mention creating a task, they mean creating it with the Backlog.md CLI tool.
214
+
215
+ <!-- BACKLOG.MD GUIDELINES END -->
LLaMA-Omni2-3B/README.md ADDED
@@ -0,0 +1,155 @@
+ # 🦙🎧 LLaMA-Omni 2: LLM-based Real-time Spoken Chatbot with Autoregressive Streaming Speech Synthesis
+
+ > **Authors: [Qingkai Fang](https://fangqingkai.github.io/), [Yan Zhou](https://zhouyan19.github.io/zhouyan/), [Shoutao Guo](https://scholar.google.com/citations?hl=en&user=XwHtPyAAAAAJ), [Shaolei Zhang](https://zhangshaolei1998.github.io/), [Yang Feng*](https://people.ucas.edu.cn/~yangfeng?language=en)**
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2505.02625-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2505.02625)
+ [![code](https://img.shields.io/badge/Github-Code-keygen.svg?logo=github)](https://github.com/ictnlp/LLaMA-Omni2)
+ [![models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging_Face-Models-blue.svg)](https://huggingface.co/collections/ICTNLP/llama-omni-67fdfb852c60470175e36e9c)
+ [![dataset](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging_Face-Dataset-blue.svg)](https://huggingface.co/datasets/ICTNLP/Multiturn-Speech-Conversations)
+
+ LLaMA-Omni 2 is a series of speech-language models built on the Qwen2.5-0.5B/1.5B/3B/7B/14B/32B-Instruct models. Like [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni), it generates text and speech responses simultaneously, enabling high-quality, low-latency speech interaction. With the newly introduced streaming autoregressive speech decoder, LLaMA-Omni 2 achieves higher speech quality than LLaMA-Omni.
+
+ <div align="center"><img src="images/llama-omni2.png" width="75%"/></div>
+
+ ## 🔥 News
+
+ - [25/05] LLaMA-Omni 2 is accepted at the ACL 2025 main conference!
+
+ ## Install
+
+ 1. Clone this repository.
+
+ ```shell
+ git clone https://github.com/ictnlp/LLaMA-Omni2
+ cd LLaMA-Omni2
+ ```
+
+ 2. Install packages.
+
+ ```shell
+ conda create -n llama-omni2 python=3.10
+ conda activate llama-omni2
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ 1. Download the `Whisper-large-v3` model. Note that this snippet is Python, not shell; run it in a Python interpreter.
+
+ ```python
+ import whisper
+ model = whisper.load_model("large-v3", download_root="models/speech_encoder/")
+ ```
+
+ 2. Download the flow-matching model and vocoder of `CosyVoice 2`.
+
+ ```shell
+ huggingface-cli download --resume-download ICTNLP/cosy2_decoder --local-dir models/cosy2_decoder
+ ```
+
+ > [!Tip]
+ > If you’re experiencing unstable connections to Hugging Face from within China, you can try setting the following in your command line:
+ >
+ > ```shell
+ > export HF_ENDPOINT=https://hf-mirror.com
+ > ```
+
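+ The same download can also be driven from Python via `huggingface_hub` — a minimal sketch, not part of the original instructions, assuming `huggingface_hub` is installed:
+
+ ```python
+ # Minimal sketch: mirror the huggingface-cli command above from Python.
+ from huggingface_hub import snapshot_download
+
+ snapshot_download(
+     repo_id="ICTNLP/cosy2_decoder",   # same repo as the CLI command above
+     local_dir="models/cosy2_decoder",
+ )
+ ```
+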
+ 3. Download the LLaMA-Omni2 series models from Hugging Face. `LLaMA-Omni2-0.5B/1.5B/3B/7B/14B` support **English only**, while `LLaMA-Omni2-0.5B/1.5B/3B/7B/14B/32B-Bilingual` support **both English and Chinese**.
+
+ ```shell
+ model_name=LLaMA-Omni2-7B-Bilingual
+ huggingface-cli download --resume-download ICTNLP/$model_name --local-dir models/$model_name
+ ```
+
+ | LLaMA-Omni2 | LLaMA-Omni2-Bilingual |
+ |---|---|
+ | 🤗 [LLaMA-Omni2-0.5B](https://huggingface.co/ICTNLP/LLaMA-Omni2-0.5B) | 🤗 [LLaMA-Omni2-0.5B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-0.5B-Bilingual) |
+ | 🤗 [LLaMA-Omni2-1.5B](https://huggingface.co/ICTNLP/LLaMA-Omni2-1.5B) | 🤗 [LLaMA-Omni2-1.5B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-1.5B-Bilingual) |
+ | 🤗 [LLaMA-Omni2-3B](https://huggingface.co/ICTNLP/LLaMA-Omni2-3B) | 🤗 [LLaMA-Omni2-3B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-3B-Bilingual) |
+ | 🤗 [LLaMA-Omni2-7B](https://huggingface.co/ICTNLP/LLaMA-Omni2-7B) | 🤗 [LLaMA-Omni2-7B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-7B-Bilingual) |
+ | 🤗 [LLaMA-Omni2-14B](https://huggingface.co/ICTNLP/LLaMA-Omni2-14B) | 🤗 [LLaMA-Omni2-14B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-14B-Bilingual) |
+ | - | 🤗 [LLaMA-Omni2-32B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-32B-Bilingual) |
+
+ ## Gradio Demo
+
+ 1. Launch a controller.
+
+ ```shell
+ python -m llama_omni2.serve.controller --host 0.0.0.0 --port 10000
+ ```
+
+ 2. Launch a Gradio web server.
+
+ ```shell
+ python -m llama_omni2.serve.gradio_web_server --controller http://localhost:10000 --port 8000 --vocoder-dir models/cosy2_decoder
+ ```
+
+ 3. Launch a model worker.
+
+ ```shell
+ python -m llama_omni2.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path models/$model_name --model-name $model_name
+ ```
+
+ 4. Visit [http://localhost:8000/](http://localhost:8000/) and interact with LLaMA-Omni2!
+
+ ## Local Inference
+
+ ```shell
+ output_dir=examples/$model_name
+ mkdir -p $output_dir
+
+ python llama_omni2/inference/run_llama_omni2.py \
+     --model_path models/$model_name \
+     --question_file examples/questions.json \
+     --answer_file $output_dir/answers.jsonl \
+     --temperature 0 \
+     --s2s
+
+ python llama_omni2/inference/run_cosy2_decoder.py \
+     --input-path $output_dir/answers.jsonl \
+     --output-dir $output_dir/wav \
+     --lang en
+ ```
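+
+ To inspect the text answers before decoding them to speech, the JSONL output can be read line by line — a minimal sketch, assuming one JSON object per line (the exact field names are not documented here; check the keys actually written by `run_llama_omni2.py`):
+
+ ```python
+ # Minimal sketch: print the records in answers.jsonl.
+ import json
+
+ with open("examples/LLaMA-Omni2-7B-Bilingual/answers.jsonl") as f:
+     for line in f:
+         print(json.loads(line))  # inspect the available keys first
+ ```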
+
+ ## LICENSE
+
+ Our code is released under the Apache-2.0 License. Our model is intended for academic research purposes only and may **NOT** be used for commercial purposes.
+
+ You are free to use, modify, and distribute this model in academic settings, provided that the following conditions are met:
+
+ - **Non-commercial use**: The model may not be used for any commercial purposes.
+ - **Citation**: If you use this model in your research, please cite the original work.
+
+ ### Commercial Use Restriction
+
+ For any commercial use inquiries or to obtain a commercial license, please contact `fengyang@ict.ac.cn`.
+
+ ## Acknowledgements
+
+ - [CosyVoice 2](https://github.com/FunAudioLLM/CosyVoice): We use the pretrained speech tokenizer, flow-matching model, and vocoder of CosyVoice 2.
+ - [SLAM-LLM](https://github.com/X-LANCE/SLAM-LLM): We borrow some code for the speech encoder and speech adaptor.
+
+ ## Citation
+
+ If you have any questions, please feel free to submit an issue or contact `fangqingkai21b@ict.ac.cn`.
+
+ If our work is useful for you, please cite as:
+
+ ```
+ @inproceedings{
+ fang2025llamaomni2,
+ title={{LL}a{MA}-{O}mni 2: LLM-based Real-time Spoken Chatbot with Autoregressive Streaming Speech Synthesis},
+ author={Fang, Qingkai and Zhou, Yan and Guo, Shoutao and Zhang, Shaolei and Feng, Yang},
+ booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics},
+ year={2025}
+ }
+
+ @inproceedings{
+ fang2025llamaomni,
+ title={{LL}a{MA}-{O}mni: Seamless Speech Interaction with Large Language Models},
+ author={Qingkai Fang and Shoutao Guo and Yan Zhou and Zhengrui Ma and Shaolei Zhang and Yang Feng},
+ booktitle={The Thirteenth International Conference on Learning Representations},
+ year={2025},
+ url={https://openreview.net/forum?id=PYmrUQmMEw}
+ }
+ ```
LLaMA-Omni2-3B/added_tokens.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "</tool_call>": 151658,
+   "<speech>": 151665,
+   "<tool_call>": 151657,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
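The `<speech>` entry at id 151665 is presumably the placeholder swapped for audio features at inference time. A quick sanity check that the tokenizer resolves it as expected — a minimal sketch, assuming `transformers` is installed and the files above live in `models/LLaMA-Omni2-3B`:

```python
# Minimal sketch: confirm <speech> maps to the id listed in added_tokens.json.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("models/LLaMA-Omni2-3B")
assert tok.convert_tokens_to_ids("<speech>") == 151665
```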
LLaMA-Omni2-3B/config.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "_name_or_path": "LLaMA-Omni2-3B",
+   "architectures": [
+     "Omni2Speech2SQwen2ForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 2048,
+   "initializer_range": 0.02,
+   "intermediate_size": 11008,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 70,
+   "model_type": "omni2_speech2s_qwen2",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 36,
+   "num_key_value_heads": 2,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": null,
+   "speech_encoder": "models/speech_encoder/large-v3.pt",
+   "speech_encoder_ds_rate": 5,
+   "speech_encoder_hidden_size": 1280,
+   "speech_encoder_type": "whisper",
+   "speech_generator": {
+     "architectures": [
+       "Qwen2ForCausalLM"
+     ],
+     "attention_dropout": 0.0,
+     "bos_token_id": 151643,
+     "eos_token_id": 151643,
+     "hidden_act": "silu",
+     "hidden_size": 896,
+     "initializer_range": 0.02,
+     "intermediate_size": 4864,
+     "max_position_embeddings": 32768,
+     "max_window_layers": 24,
+     "model_type": "qwen2",
+     "num_attention_heads": 14,
+     "num_hidden_layers": 24,
+     "num_key_value_heads": 2,
+     "rms_norm_eps": 1e-06,
+     "rope_theta": 1000000.0,
+     "sliding_window": null,
+     "tie_word_embeddings": true,
+     "torch_dtype": "bfloat16",
+     "transformers_version": "4.43.4",
+     "use_cache": true,
+     "use_mrope": false,
+     "use_sliding_window": false,
+     "vocab_size": 158227
+   },
+   "speech_projector_type": "linear",
+   "stream_params": "(3,10)",
+   "tie_word_embeddings": true,
+   "tokenizer_model_max_length": 4096,
+   "tokenizer_padding_side": "right",
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.43.4",
+   "unit_vocab_size": 6561,
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151936
+ }
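A note on `config.json`: it nests a second, smaller Qwen2 config under `speech_generator` (the streaming TTS decoder). A minimal sketch for inspecting both levels with the standard library, assuming the file sits at the path used elsewhere in this repo:

```python
# Minimal sketch: compare the main LM config with the nested TTS decoder config.
import json

with open("models/LLaMA-Omni2-3B/config.json") as f:
    cfg = json.load(f)

print(cfg["hidden_size"], cfg["num_hidden_layers"])  # main LM: 2048, 36
gen = cfg["speech_generator"]
print(gen["hidden_size"], gen["num_hidden_layers"])  # speech decoder: 896, 24
print(cfg["unit_vocab_size"])                        # 6561 speech units; stream_params is parsed by the modeling code
```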
LLaMA-Omni2-3B/generation_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "attn_implementation": "flash_attention_2",
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "repetition_penalty": 1.05,
+   "temperature": 0.7,
+   "top_k": 20,
+   "top_p": 0.8,
+   "transformers_version": "4.43.4"
+ }
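Since this is a standard Transformers `generation_config.json`, it can be loaded directly — a minimal sketch, assuming `transformers` is installed and the same local model directory as above:

```python
# Minimal sketch: read the sampling defaults shipped with the model.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("models/LLaMA-Omni2-3B")
print(gen_cfg.temperature, gen_cfg.top_p, gen_cfg.top_k)  # 0.7, 0.8, 20
print(gen_cfg.eos_token_id)                               # [151645, 151643]
```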
LLaMA-Omni2-3B/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
LLaMA-Omni2-3B/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc4b7fda5d470353f675e0410724af00479eb09b3c81c5648bad35ac97904665
+ size 4957560304
LLaMA-Omni2-3B/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60dd465c6e6ceac492af19fbdbe11dd9fb4104b2a71c3825fba38a8a0427ed94
+ size 4455567096
LLaMA-Omni2-3B/model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
LLaMA-Omni2-3B/special_tokens_map.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<|endoftext|>"
+ }
LLaMA-Omni2-3B/tokenizer_config.json ADDED
@@ -0,0 +1,216 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<speech>",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "model_max_length": 4096,
+   "pad_token": "<|endoftext|>",
+   "padding_side": "right",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
LLaMA-Omni2-3B/tts_tokenizer/added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
LLaMA-Omni2-3B/tts_tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
LLaMA-Omni2-3B/tts_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<|endoftext|>"
+ }
LLaMA-Omni2-3B/tts_tokenizer/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
LLaMA-Omni2-3B/tts_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
LLaMA-Omni2-3B/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
README.md ADDED
@@ -0,0 +1,124 @@
+ # 🎙️🤖 Goodspace Voice Agent: LLM-based Real-time Spoken Chatbot with Autoregressive Streaming Speech Synthesis
+
+ > **Powered by advanced speech-language models and streaming synthesis technology**
+
+ [![code](https://img.shields.io/badge/Github-Code-keygen.svg?logo=github)](https://github.com/goodspace/voice-agent)
+ [![models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging_Face-Models-blue.svg)](https://huggingface.co/collections/goodspace/voice-agent)
+ [![dataset](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging_Face-Dataset-blue.svg)](https://huggingface.co/datasets/goodspace/speech-conversations)
+
+ Goodspace Voice Agent is a series of speech-language models built on the Qwen2.5-0.5B/1.5B/3B/7B/14B/32B-Instruct models. It generates text and speech responses simultaneously, enabling high-quality, low-latency speech interaction. With the streaming autoregressive speech decoder, Goodspace Voice Agent achieves high speech quality and natural conversation flow.
+
+ <div align="center"><img src="images/llama-omni2.png" width="75%"/></div>
+
+ ## 🔥 News
+
+ - Goodspace Voice Agent - Advanced real-time voice interaction system now available!
+
+ ## Install
+
+ 1. Clone this repository.
+
+ ```shell
+ git clone https://github.com/goodspace/voice-agent
+ cd voice-agent
+ ```
+
+ 2. Install packages.
+
+ ```shell
+ conda create -n goodspace-voice python=3.10
+ conda activate goodspace-voice
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ 1. Download the `Whisper-large-v3` model. Note that this snippet is Python, not shell; run it in a Python interpreter.
+
+ ```python
+ import whisper
+ model = whisper.load_model("large-v3", download_root="models/speech_encoder/")
+ ```
+
+ 2. Download the flow-matching model and vocoder of `CosyVoice 2`.
+
+ ```shell
+ huggingface-cli download --resume-download goodspace/cosy2_decoder --local-dir models/cosy2_decoder
+ ```
+
+ > [!Tip]
+ > If you’re experiencing unstable connections to Hugging Face from within China, you can try setting the following in your command line:
+ >
+ > ```shell
+ > export HF_ENDPOINT=https://hf-mirror.com
+ > ```
+
+ 3. Download the Goodspace Voice Agent models from Hugging Face. `GoodspaceVoice-0.5B/1.5B/3B/7B/14B` support **English only**, while `GoodspaceVoice-0.5B/1.5B/3B/7B/14B/32B-Bilingual` support **both English and Chinese**.
+
+ ```shell
+ model_name=GoodspaceVoice-7B-Bilingual
+ huggingface-cli download --resume-download goodspace/$model_name --local-dir models/$model_name
+ ```
+
+ ## Gradio Demo
+
+ 1. Launch a controller.
+
+ ```shell
+ python -m goodspace_voice.serve.controller --host 0.0.0.0 --port 10000
+ ```
+
+ 2. Launch a Gradio web server.
+
+ ```shell
+ python -m goodspace_voice.serve.gradio_web_server --controller http://localhost:10000 --port 8000 --vocoder-dir models/cosy2_decoder
+ ```
+
+ 3. Launch a model worker.
+
+ ```shell
+ python -m goodspace_voice.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path models/$model_name --model-name $model_name
+ ```
+
+ 4. Visit [http://localhost:8000/](http://localhost:8000/) and interact with GoodspaceVoice!
+
+ ## Local Inference
+
+ ```shell
+ output_dir=examples/$model_name
+ mkdir -p $output_dir
+
+ python goodspace_voice/inference/run_goodspace_voice.py \
+     --model_path models/$model_name \
+     --question_file examples/questions.json \
+     --answer_file $output_dir/answers.jsonl \
+     --temperature 0 \
+     --s2s
+
+ python goodspace_voice/inference/run_cosy2_decoder.py \
+     --input-path $output_dir/answers.jsonl \
+     --output-dir $output_dir/wav \
+     --lang en
+ ```
+
+ ## LICENSE
+
+ The Goodspace Voice Agent is released under the Apache-2.0 License.
+
+ ### Commercial Use
+
+ For commercial use inquiries or licensing information, please contact the Goodspace team.
+
+ ## Acknowledgements
+
+ - [CosyVoice 2](https://github.com/FunAudioLLM/CosyVoice): We use the pretrained speech tokenizer, flow-matching model, and vocoder of CosyVoice 2.
+ - [SLAM-LLM](https://github.com/X-LANCE/SLAM-LLM): We borrow some code for the speech encoder and speech adaptor.
+ - Based on the research from the LLaMA-Omni2 paper.
+
+ ## Support
+
+ If you have any questions or issues, please feel free to submit an issue on our GitHub repository.
+
+ ## Contributing
+
+ We welcome contributions! Please see our contributing guidelines for more information.
SETUP_GUIDE.md ADDED
@@ -0,0 +1,274 @@
+ # LLaMA-Omni2 Voice Assistant Setup Guide
+
+ This guide provides comprehensive instructions for reproducing the exact environment and setup for the LLaMA-Omni2 voice assistant with CosyVoice2 integration.
+
+ ## Prerequisites
+
+ - Ubuntu/Linux system with a CUDA-capable GPU
+ - CUDA 12.1 or higher installed
+ - Miniconda or Anaconda installed
+ - At least 16GB RAM and 20GB free disk space
+ - Python 3.10
+
+ ## Environment Setup Options
+
+ ### Option 1: Using Conda Environment File (Recommended)
+
+ ```bash
+ # Create environment from comprehensive yml file
+ conda env create -f environment-comprehensive.yml
+
+ # Activate the environment
+ conda activate gsva-python310
+ ```
+
+ ### Option 2: Using Frozen Requirements
+
+ ```bash
+ # Create a new conda environment
+ conda create -n gsva-python310 python=3.10 -y
+ conda activate gsva-python310
+
+ # Install from frozen requirements
+ pip install -r requirements-frozen-new.txt
+ ```
+
+ ### Option 3: Manual Setup Using Script
+
+ ```bash
+ # Run the complete setup script
+ bash script.sh
+ ```
+
+ ## Detailed Manual Setup
+
+ ### 1. Create and Activate Conda Environment
+
+ ```bash
+ source /home/azureuser/miniconda3/etc/profile.d/conda.sh
+ conda create -n gsva-python310 python=3.10 -y
+ conda activate gsva-python310
+ ```
+
+ ### 2. Install Basic Dependencies
+
+ ```bash
+ pip install Cython numpy==1.26.4
+ pip install packaging wheel setuptools==69.5.1
+ ```
+
+ ### 3. Install the Package
+
+ ```bash
+ # Install in development mode
+ pip install -e .
+ ```
+
+ ### 4. Install Core Dependencies
+
+ ```bash
+ # Essential packages
+ pip install huggingface_hub==0.25.1
+ pip install uvicorn openai-whisper fastapi
+ pip install hf_transfer ninja
+
+ # Gradio for web interface
+ pip install gradio==5.3.0 gradio_client==1.4.2
+ ```
+
+ ### 5. Set Up CUDA Environment
+
+ ```bash
+ # Link the CUDA installation (12.6 in this setup; any 12.1+ install works the same way)
+ sudo rm -rf /usr/local/cuda
+ sudo ln -s /usr/local/cuda-12.6 /usr/local/cuda
+ export PATH=/usr/local/cuda/bin:$PATH
+ export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ ```
+
+ ### 6. Install PyTorch with CUDA Support
+
+ ```bash
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+ ```
+
+ ### 7. Install Flash Attention
+
+ ```bash
+ MAX_JOBS=4 pip install flash-attn --no-build-isolation
+ ```
+
+ ### 8. Install Transformers and Audio Libraries
+
+ ```bash
+ # Specific version for LLaMA-Omni2 compatibility
+ pip install transformers==4.43.4
+
+ # Audio processing libraries
+ pip install matcha-tts --no-build-isolation
+ pip install git+https://github.com/FunAudioLLM/CosyVoice.git
+
+ # Additional dependencies
+ pip install conformer onnxruntime hyperpyyaml==1.2.2 ruamel.yaml
+ ```
+
+ ## Model Downloads
+
+ ### 1. Download LLaMA-Omni2 Model
+
+ ```bash
+ mkdir -p models
+ huggingface-cli download ICTNLP/LLaMA-Omni2-3B --local-dir models/LLaMA-Omni2-3B
+ ```
+
+ ### 2. Download CosyVoice2 Model
+
+ ```bash
+ mkdir -p models/cosyvoice2
+ python -c "
+ from huggingface_hub import snapshot_download
+ import os
+ os.makedirs('models/cosyvoice2', exist_ok=True)
+ snapshot_download(
+     repo_id='FunAudioLLM/CosyVoice2-0.5B',
+     local_dir='models/cosyvoice2',
+     local_dir_use_symlinks=False
+ )
+ "
+ ```
+
+ ### 3. Fix CosyVoice Configuration
+
+ ```bash
+ # Create backup
+ cp models/cosyvoice2/cosyvoice2.yaml models/cosyvoice2/cosyvoice2.yaml.backup
+
+ # Copy to expected filename
+ cp models/cosyvoice2/cosyvoice2.yaml models/cosyvoice2/cosyvoice.yaml
+
+ # Remove problematic parameter
+ grep -v "mix_ratio" models/cosyvoice2/cosyvoice.yaml > models/cosyvoice2/cosyvoice_fixed.yaml
+ mv models/cosyvoice2/cosyvoice_fixed.yaml models/cosyvoice2/cosyvoice.yaml
+ ```
154
+ ## Running the Services
155
+
156
+ ### 1. Start Controller
157
+
158
+ ```bash
159
+ nohup python -m llama_omni2.serve.controller \
160
+ --host 0.0.0.0 \
161
+ --port 10000 > controller.log 2>&1 &
162
+ ```
163
+
164
+ ### 2. Start Model Worker
165
+
166
+ ```bash
167
+ nohup python -m llama_omni2.serve.model_worker \
168
+ --host 0.0.0.0 \
169
+ --controller http://localhost:10000 \
170
+ --port 40000 \
171
+ --worker http://localhost:40000 \
172
+ --model-path models/LLaMA-Omni2-3B \
173
+ --model-name LLaMA-Omni2-3B > worker.log 2>&1 &
174
+ ```
175
+
176
+ ### 3. Start Gradio Web Server
177
+
178
+ With CosyVoice2 vocoder:
179
+ ```bash
180
+ python -m llama_omni2.serve.gradio_web_server \
181
+ --controller http://localhost:10000 \
182
+ --port 8000 \
183
+ --vocoder-dir models/cosyvoice2
184
+ ```
185
+
186
+ Without vocoder (fallback):
187
+ ```bash
188
+ python -m llama_omni2.serve.gradio_web_server \
189
+ --controller http://localhost:10000 \
190
+ --port 8000
191
+ ```
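+ Before opening the UI, you can verify that each service's port is actually listening — a minimal sketch using only the standard library (ports taken from the commands above):
+
+ ```python
+ # Minimal sketch: check that the controller, web server, and worker ports are open.
+ import socket
+
+ for name, port in [("controller", 10000), ("web server", 8000), ("worker", 40000)]:
+     with socket.socket() as s:
+         s.settimeout(1.0)
+         ok = s.connect_ex(("localhost", port)) == 0
+     print(f"{name} on port {port}: {'up' if ok else 'down'}")
+ ```
+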
192
+
193
+ ## Monitoring Services
194
+
195
+ ```bash
196
+ # Check controller logs
197
+ tail -f controller.log
198
+
199
+ # Check model worker logs
200
+ tail -f worker.log
201
+
202
+ # Access web UI
203
+ # Open browser at http://localhost:8000
204
+ ```
205
+
206
+ ## Troubleshooting
207
+
208
+ ### Common Issues
209
+
210
+ 1. **CUDA not found**: Ensure CUDA paths are exported correctly
211
+ 2. **Flash attention build fails**: Use `MAX_JOBS=4` to limit parallel compilation
212
+ 3. **CosyVoice mix_ratio error**: Follow the configuration fix steps above
213
+ 4. **Port already in use**: Kill existing processes or use different ports
214
+
215
+ ### Killing Services
216
+
217
+ ```bash
218
+ # Find and kill Python processes
219
+ ps aux | grep python | grep -E "(controller|model_worker|gradio_web_server)" | awk '{print $2}' | xargs -r kill
220
+ ```
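+
+ If `psutil` is available, the same cleanup can be done from Python with a bit more safety, since you can inspect each command line before killing — a minimal sketch, assuming `pip install psutil`:
+
+ ```python
+ # Minimal sketch: terminate the three LLaMA-Omni2 service processes.
+ import psutil
+
+ targets = ("controller", "model_worker", "gradio_web_server")
+ for proc in psutil.process_iter(["pid", "cmdline"]):
+     cmd = " ".join(proc.info["cmdline"] or [])
+     if "python" in cmd and any(t in cmd for t in targets):
+         print("killing", proc.info["pid"], cmd)
+         proc.terminate()
+ ```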
+
+ ## Project Structure
+
+ ```
+ voiceagents/
+ ├── llama_omni2/                   # Main application code
+ ├── cosyvoice/                     # CosyVoice integration
+ ├── models/                        # Downloaded models
+ │   ├── LLaMA-Omni2-3B/
+ │   └── cosyvoice2/
+ ├── examples/                      # Sample audio files
+ ├── script.sh                      # Setup script
+ ├── pyproject.toml                 # Project configuration
+ ├── requirements-frozen-new.txt    # Frozen dependencies
+ ├── environment-comprehensive.yml  # Conda environment
+ └── SETUP_GUIDE.md                 # This file
+ ```
+
+ ## Environment Variables
+
+ Set these in your `.bashrc` or `.zshrc`:
+
+ ```bash
+ export PATH=/usr/local/cuda/bin:$PATH
+ export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ export HF_HUB_ENABLE_HF_TRANSFER=1
+ export HF_HOME=~/.cache/huggingface
+ export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0"
+ export MAX_JOBS=4
+ ```
+
+ ## Version Information
+
+ - Python: 3.10
+ - PyTorch: 2.3.1
+ - Transformers: 4.43.4
+ - Gradio: 5.3.0
+ - CUDA: 12.1+
+ - CosyVoice2: 0.5B model
+
+ ## Additional Notes
+
+ - The setup has been tested on Ubuntu with NVIDIA GPUs
+ - Ensure sufficient GPU memory (8GB+ recommended)
+ - For production deployment, consider using systemd services
+ - Regular backups of models and configurations are recommended
+
+ ## Support
+
+ For issues or questions:
+ - Check the logs in controller.log and worker.log
+ - Ensure all dependencies are correctly installed
+ - Verify CUDA is properly configured
+ - Review COSYVOICE2_CHANGES.md for model-specific details
controller.log.2025-08-16 ADDED
@@ -0,0 +1,6 @@
+ 2025-08-16 15:21:01 | INFO | controller | args: Namespace(host='0.0.0.0', port=10000, dispatch_method='shortest_queue')
+ 2025-08-16 15:21:01 | INFO | controller | Init controller
+ 2025-08-16 15:21:01 | ERROR | stderr | INFO: Started server process [32029]
+ 2025-08-16 15:21:01 | ERROR | stderr | INFO: Waiting for application startup.
+ 2025-08-16 15:21:01 | ERROR | stderr | INFO: Application startup complete.
+ 2025-08-16 15:21:01 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:10000 (Press CTRL+C to quit)
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/bin/average_model.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (c) 2020 Mobvoi Inc (Di Wu)
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import argparse
+ import glob
+
+ import yaml
+ import torch
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='average model')
+     parser.add_argument('--dst_model', required=True, help='averaged model')
+     parser.add_argument('--src_path',
+                         required=True,
+                         help='src model path for average')
+     parser.add_argument('--val_best',
+                         action="store_true",
+                         help='averaged model')
+     parser.add_argument('--num',
+                         default=5,
+                         type=int,
+                         help='nums for averaged model')
+
+     args = parser.parse_args()
+     print(args)
+     return args
+
+
+ def main():
+     args = get_args()
+     val_scores = []
+     if args.val_best:
+         yamls = glob.glob('{}/*.yaml'.format(args.src_path))
+         yamls = [
+             f for f in yamls
+             if not (os.path.basename(f).startswith('train')
+                     or os.path.basename(f).startswith('init'))
+         ]
+         for y in yamls:
+             with open(y, 'r') as f:
+                 dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
+                 loss = float(dic_yaml['loss_dict']['loss'])
+                 epoch = int(dic_yaml['epoch'])
+                 step = int(dic_yaml['step'])
+                 tag = dic_yaml['tag']
+                 val_scores += [[epoch, step, loss, tag]]
+         sorted_val_scores = sorted(val_scores,
+                                    key=lambda x: x[2],
+                                    reverse=False)
+         print("best val (epoch, step, loss, tag) = " +
+               str(sorted_val_scores[:args.num]))
+         path_list = [
+             args.src_path + '/epoch_{}_whole.pt'.format(score[0])
+             for score in sorted_val_scores[:args.num]
+         ]
+         print(path_list)
+     avg = {}
+     num = args.num
+     assert num == len(path_list)
+     for path in path_list:
+         print('Processing {}'.format(path))
+         states = torch.load(path, map_location=torch.device('cpu'))
+         for k in states.keys():
+             if k not in avg.keys():
+                 avg[k] = states[k].clone()
+             else:
+                 avg[k] += states[k]
+     # average
+     for k in avg.keys():
+         if avg[k] is not None:
+             # pytorch 1.6 use true_divide instead of /=
+             avg[k] = torch.true_divide(avg[k], num)
+     print('Saving to {}'.format(args.dst_model))
+     torch.save(avg, args.dst_model)
+
+
+ if __name__ == '__main__':
+     main()
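Usage note: `average_model.py` implements uniform state-dict averaging over the best-scoring checkpoints. A minimal standalone sketch of the same idea, with hypothetical checkpoint paths:

```python
# Minimal sketch of the uniform checkpoint averaging done above.
# The paths are hypothetical; any list of compatible state dicts works.
import torch

paths = ["epoch_3_whole.pt", "epoch_4_whole.pt"]
avg = {}
for p in paths:
    for k, v in torch.load(p, map_location="cpu").items():
        avg[k] = v.clone() if k not in avg else avg[k] + v
avg = {k: torch.true_divide(v, len(paths)) for k, v in avg.items()}
torch.save(avg, "averaged.pt")
```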
cosyvoice/bin/export_jit.py ADDED
@@ -0,0 +1,74 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+
+ import argparse
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ import os
+ import sys
+ import torch
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append('{}/../..'.format(ROOT_DIR))
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+ from cosyvoice.cli.cosyvoice import CosyVoice
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='export your model for deployment')
+     parser.add_argument('--model_dir',
+                         type=str,
+                         default='pretrained_models/CosyVoice-300M',
+                         help='local path')
+     args = parser.parse_args()
+     print(args)
+     return args
+
+
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+
+     torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+     torch._C._jit_set_profiling_mode(False)
+     torch._C._jit_set_profiling_executor(False)
+
+     cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+
+     # 1. export llm text_encoder
+     llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
+     script = torch.jit.script(llm_text_encoder)
+     script = torch.jit.freeze(script)
+     script = torch.jit.optimize_for_inference(script)
+     script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+
+     # 2. export llm llm
+     llm_llm = cosyvoice.model.llm.llm.half()
+     script = torch.jit.script(llm_llm)
+     script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
+     script = torch.jit.optimize_for_inference(script)
+     script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+
+     # 3. export flow encoder
+     flow_encoder = cosyvoice.model.flow.encoder
+     script = torch.jit.script(flow_encoder)
+     script = torch.jit.freeze(script)
+     script = torch.jit.optimize_for_inference(script)
+     script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+
+
+ if __name__ == '__main__':
+     main()
cosyvoice/bin/export_onnx.py ADDED
@@ -0,0 +1,112 @@
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+
+ import argparse
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ import os
+ import sys
+ import onnxruntime
+ import random
+ import torch
+ from tqdm import tqdm
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append('{}/../..'.format(ROOT_DIR))
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+ from cosyvoice.cli.cosyvoice import CosyVoice
+
+
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
+     x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+     mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+     mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+     t = torch.rand((batch_size), dtype=torch.float32, device=device)
+     spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+     cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+     return x, mask, mu, t, spks, cond
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='export your model for deployment')
+     parser.add_argument('--model_dir',
+                         type=str,
+                         default='pretrained_models/CosyVoice-300M',
+                         help='local path')
+     args = parser.parse_args()
+     print(args)
+     return args
+
+
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+
+     cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+
+     # 1. export flow decoder estimator
+     estimator = cosyvoice.model.flow.decoder.estimator
+
+     device = cosyvoice.model.device
+     batch_size, seq_len = 1, 256
+     out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
+     x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+     torch.onnx.export(
+         estimator,
+         (x, mask, mu, t, spks, cond),
+         '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+         export_params=True,
+         opset_version=18,
+         do_constant_folding=True,
+         input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+         output_names=['estimator_out'],
+         dynamic_axes={
+             'x': {0: 'batch_size', 2: 'seq_len'},
+             'mask': {0: 'batch_size', 2: 'seq_len'},
+             'mu': {0: 'batch_size', 2: 'seq_len'},
+             'cond': {0: 'batch_size', 2: 'seq_len'},
+             't': {0: 'batch_size'},
+             'spks': {0: 'batch_size'},
+             'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+         }
+     )
+
+     # 2. test computation consistency
+     option = onnxruntime.SessionOptions()
+     option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+     option.intra_op_num_threads = 1
+     providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+     estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                   sess_options=option, providers=providers)
+
+     for _ in tqdm(range(10)):
+         x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+         output_pytorch = estimator(x, mask, mu, t, spks, cond)
+         ort_inputs = {
+             'x': x.cpu().numpy(),
+             'mask': mask.cpu().numpy(),
+             'mu': mu.cpu().numpy(),
+             't': t.cpu().numpy(),
+             'spks': spks.cpu().numpy(),
+             'cond': cond.cpu().numpy()
+         }
+         output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+         torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+
+
+ if __name__ == "__main__":
+     main()
cosyvoice/bin/export_trt.sh ADDED
@@ -0,0 +1,9 @@
+ #!/bin/bash
+ # Copyright 2024 Alibaba Inc. All Rights Reserved.
+ # Download TensorRT from https://developer.nvidia.com/tensorrt/download/10x; check your system and CUDA version for compatibility.
+ # For example, for Linux + CUDA 12.4 you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
+ TRT_DIR=<YOUR_TRT_DIR>
+ MODEL_DIR=<COSYVOICE2_MODEL_DIR>
+
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
+ $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,115 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+
+ import argparse
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ import os
+ import torch
+ from torch.utils.data import DataLoader
+ import torchaudio
+ from hyperpyyaml import load_hyperpyyaml
+ from tqdm import tqdm
+ from cosyvoice.cli.model import CosyVoiceModel
+ from cosyvoice.dataset.dataset import Dataset
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='inference with your model')
+     parser.add_argument('--config', required=True, help='config file')
+     parser.add_argument('--prompt_data', required=True, help='prompt data file')
+     parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
+     parser.add_argument('--tts_text', required=True, help='tts input file')
+     parser.add_argument('--llm_model', required=True, help='llm model file')
+     parser.add_argument('--flow_model', required=True, help='flow model file')
+     parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+     parser.add_argument('--gpu',
+                         type=int,
+                         default=-1,
+                         help='gpu id for this rank, -1 for cpu')
+     parser.add_argument('--mode',
+                         default='sft',
+                         choices=['sft', 'zero_shot'],
+                         help='inference mode')
+     parser.add_argument('--result_dir', required=True, help='asr result file')
+     args = parser.parse_args()
+     print(args)
+     return args
+
+
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+     # Init cosyvoice models from configs
+     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+     device = torch.device('cuda' if use_cuda else 'cpu')
+     with open(args.config, 'r') as f:
+         configs = load_hyperpyyaml(f)
+
+     model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+     model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+     test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+                            tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+     del configs
+     os.makedirs(args.result_dir, exist_ok=True)
+     fn = os.path.join(args.result_dir, 'wav.scp')
+     f = open(fn, 'w')
+     with torch.no_grad():
+         for _, batch in tqdm(enumerate(test_data_loader)):
+             utts = batch["utts"]
+             assert len(utts) == 1, "inference mode only support batchsize 1"
+             text_token = batch["text_token"].to(device)
+             text_token_len = batch["text_token_len"].to(device)
+             tts_index = batch["tts_index"]
+             tts_text_token = batch["tts_text_token"].to(device)
+             tts_text_token_len = batch["tts_text_token_len"].to(device)
+             speech_token = batch["speech_token"].to(device)
+             speech_token_len = batch["speech_token_len"].to(device)
+             speech_feat = batch["speech_feat"].to(device)
+             speech_feat_len = batch["speech_feat_len"].to(device)
+             utt_embedding = batch["utt_embedding"].to(device)
+             spk_embedding = batch["spk_embedding"].to(device)
+             if args.mode == 'sft':
+                 model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                                'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+             else:
+                 model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                                'prompt_text': text_token, 'prompt_text_len': text_token_len,
+                                'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                                'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                                'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+             tts_speeches = []
+             for model_output in model.tts(**model_input):
+                 tts_speeches.append(model_output['tts_speech'])
+             tts_speeches = torch.concat(tts_speeches, dim=1)
+             tts_key = '{}_{}'.format(utts[0], tts_index[0])
+             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+             torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
+             f.write('{} {}\n'.format(tts_key, tts_fn))
+             f.flush()
+     f.close()
+     logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+ if __name__ == '__main__':
+     main()
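Usage note: `inference.py` writes a Kaldi-style `wav.scp` with one `key path` pair per line (per the `f.write` above). A minimal sketch for loading the results back, assuming `torchaudio` is installed and the directory matches the `--result_dir` used:

```python
# Minimal sketch: read the wav.scp written by inference.py and load each wav.
import torchaudio

with open("result_dir/wav.scp") as f:
    for line in f:
        key, path = line.strip().split(maxsplit=1)
        wav, sr = torchaudio.load(path)  # sr is 22050, per torchaudio.save above
        print(key, wav.shape, sr)
```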
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+ import argparse
+ import datetime
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ from copy import deepcopy
+ import os
+ import torch
+ import torch.distributed as dist
+ import deepspeed
+
+ from hyperpyyaml import load_hyperpyyaml
+
+ from torch.distributed.elastic.multiprocessing.errors import record
+
+ from cosyvoice.utils.executor import Executor
+ from cosyvoice.utils.train_utils import (
+     init_distributed,
+     init_dataset_and_dataloader,
+     init_optimizer_and_scheduler,
+     init_summarywriter, save_model,
+     wrap_cuda_model, check_modify_and_save_config)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='training your network')
+     parser.add_argument('--train_engine',
+                         default='torch_ddp',
+                         choices=['torch_ddp', 'deepspeed'],
+                         help='Engine for parallel training')
+     parser.add_argument('--model', required=True, help='model which will be trained')
+     parser.add_argument('--config', required=True, help='config file')
+     parser.add_argument('--train_data', required=True, help='train data file')
+     parser.add_argument('--cv_data', required=True, help='cv data file')
+     parser.add_argument('--checkpoint', help='checkpoint model')
+     parser.add_argument('--model_dir', required=True, help='save model dir')
+     parser.add_argument('--tensorboard_dir',
+                         default='tensorboard',
+                         help='tensorboard log dir')
+     parser.add_argument('--ddp.dist_backend',
+                         dest='dist_backend',
+                         default='nccl',
+                         choices=['nccl', 'gloo'],
+                         help='distributed backend')
+     parser.add_argument('--num_workers',
+                         default=0,
+                         type=int,
+                         help='num of subprocess workers for reading')
+     parser.add_argument('--prefetch',
+                         default=100,
+                         type=int,
+                         help='prefetch number')
+     parser.add_argument('--pin_memory',
+                         action='store_true',
+                         default=False,
+                         help='Use pinned memory buffers for reading')
+     parser.add_argument('--use_amp',
+                         action='store_true',
+                         default=False,
+                         help='Use automatic mixed precision training')
+     parser.add_argument('--deepspeed.save_states',
+                         dest='save_states',
+                         default='model_only',
+                         choices=['model_only', 'model+optimizer'],
+                         help='save model/optimizer states')
+     parser.add_argument('--timeout',
+                         default=60,
+                         type=int,
+                         help='timeout (in seconds) of cosyvoice_join.')
+     parser = deepspeed.add_config_arguments(parser)
+     args = parser.parse_args()
+     return args
+
+
+ @record
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+     # GAN training has some special initialization logic
+     gan = True if args.model == 'hifigan' else False
+
+     override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+     if gan is True:
+         override_dict.pop('hift')
+     with open(args.config, 'r') as f:
+         configs = load_hyperpyyaml(f, overrides=override_dict)
+     if gan is True:
+         configs['train_conf'] = configs['train_conf_gan']
+     configs['train_conf'].update(vars(args))
+
+     # Init env for ddp
+     init_distributed(args)
+
+     # Get dataset & dataloader
+     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+         init_dataset_and_dataloader(args, configs, gan)
+
+     # Do some sanity checks and save config to args.model_dir
+     configs = check_modify_and_save_config(args, configs)
+
+     # Tensorboard summary
+     writer = init_summarywriter(args)
+
+     # load checkpoint
+     model = configs[args.model]
+     start_step, start_epoch = 0, -1
+     if args.checkpoint is not None:
+         if os.path.exists(args.checkpoint):
+             state_dict = torch.load(args.checkpoint, map_location='cpu')
+             model.load_state_dict(state_dict, strict=False)
+             if 'step' in state_dict:
+                 start_step = state_dict['step']
+             if 'epoch' in state_dict:
+                 start_epoch = state_dict['epoch']
+         else:
+             logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
+
+     # Dispatch model from cpu to gpu
+     model = wrap_cuda_model(args, model)
+
+     # Get optimizer & scheduler
+     model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+     scheduler.set_step(start_step)
+     if scheduler_d is not None:
+         scheduler_d.set_step(start_step)
+
+     # Save init checkpoints
+     info_dict = deepcopy(configs['train_conf'])
+     info_dict['step'] = start_step
+     info_dict['epoch'] = start_epoch
+     save_model(model, 'init', info_dict)
+
+     # Get executor
+     executor = Executor(gan=gan)
+     executor.step = start_step
+
+     # Init scaler, used for pytorch amp mixed precision training
+     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+     print('start step {} start epoch {}'.format(start_step, start_epoch))
+     # Start training loop
+     for epoch in range(start_epoch + 1, info_dict['max_epoch']):
+         executor.epoch = epoch
+         train_dataset.set_epoch(epoch)
+         dist.barrier()
+         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+         if gan is True:
+             executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                                         writer, info_dict, scaler, group_join)
+         else:
+             executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+         dist.destroy_process_group(group_join)
+
+
+ if __name__ == '__main__':
+     main()
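
Note: the resume logic above expects the checkpoint to be an ordinary state_dict that may additionally carry 'step' and 'epoch' entries; since it loads with strict=False, the extra keys are simply ignored by load_state_dict. A minimal sketch of writing a compatible checkpoint (the Linear module is a hypothetical stand-in for one of the CosyVoice sub-models):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)          # hypothetical stand-in model
ckpt = dict(model.state_dict())  # plain weight dict
ckpt['step'] = 1000              # picked up as start_step on resume
ckpt['epoch'] = 2                # picked up as start_epoch on resume
torch.save(ckpt, 'checkpoint.pt')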
cosyvoice/cli/__init__.py ADDED
File without changes
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,170 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import time
+ from tqdm import tqdm
+ from hyperpyyaml import load_hyperpyyaml
+ from modelscope import snapshot_download
+ import torch
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+ from cosyvoice.utils.file_utils import logging
+
+
+ class CosyVoice:
+
+     def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+         instruct = True if '-Instruct' in model_dir else False
+         self.model_dir = model_dir
+         if not os.path.exists(model_dir):
+             model_dir = snapshot_download(model_dir)
+         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+             configs = load_hyperpyyaml(f)
+         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                           configs['feat_extractor'],
+                                           '{}/campplus.onnx'.format(model_dir),
+                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+                                           '{}/spk2info.pt'.format(model_dir),
+                                           instruct,
+                                           configs['allowed_special'])
+         self.sample_rate = configs['sample_rate']
+         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+             load_jit = False
+             fp16 = False
+             logging.warning('CPU does not support fp16 and JIT, forcing both to False')
+         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
+         self.model.load('{}/llm.pt'.format(model_dir),
+                         '{}/flow.pt'.format(model_dir),
+                         '{}/hift.pt'.format(model_dir))
+         if load_jit:
+             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                 '{}/llm.llm.fp16.zip'.format(model_dir),
+                                 '{}/flow.encoder.fp32.zip'.format(model_dir))
+         if load_onnx:
+             self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+         del configs
+
+     def list_avaliable_spks(self):
+         spks = list(self.frontend.spk2info.keys())
+         return spks
+
+     def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
+         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+             model_input = self.frontend.frontend_sft(i, spk_id)
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                 speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()
+
+     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
+         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
+         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+             if len(i) < 0.5 * len(prompt_text):
+                 logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
+             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                 speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()
+
+     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
+         if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
+             raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
+         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                 speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()
+
+     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
+         assert isinstance(self.model, CosyVoiceModel)
+         if self.frontend.instruct is False:
+             raise ValueError('{} does not support instruct inference'.format(self.model_dir))
+         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
+         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                 speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()
+
+     def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0):
+         assert isinstance(self.model, CosyVoice2Model)
+         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+             model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+             start_time = time.time()
+             logging.info('synthesis text {}'.format(i))
+             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                 speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                 yield model_output
+                 start_time = time.time()
+
+     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
+         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+         start_time = time.time()
+         for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
+             speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+             yield model_output
+             start_time = time.time()
+
+
+ class CosyVoice2(CosyVoice):
+
+     def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
+         instruct = True if '-Instruct' in model_dir else False
+         self.model_dir = model_dir
+         if not os.path.exists(model_dir):
+             model_dir = snapshot_download(model_dir)
+         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                           configs['feat_extractor'],
+                                           '{}/campplus.onnx'.format(model_dir),
+                                           '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                           '{}/spk2info.pt'.format(model_dir),
+                                           instruct,
+                                           configs['allowed_special'])
+         self.sample_rate = configs['sample_rate']
+         if torch.cuda.is_available() is False and load_jit is True:
+             load_jit = False
+             logging.warning('CPU does not support JIT, forcing load_jit to False')
+         self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+         self.model.load('{}/llm.pt'.format(model_dir),
+                         '{}/flow.pt'.format(model_dir),
+                         '{}/hift.pt'.format(model_dir))
+         if load_jit:
+             self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
+         if load_trt is True and load_onnx is True:
+             load_onnx = False
+             logging.warning('cannot set both load_trt and load_onnx to True, forcing load_onnx to False')
+         if load_onnx:
+             self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+         if load_trt:
+             self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
+         del configs
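
Note: all inference_* methods above are generators, yielding one audio chunk per normalized text segment. A usage sketch (the model directory and speaker id are placeholders; substitute a pretrained model you actually have):

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

# placeholder model dir; it is fetched via modelscope snapshot_download if absent locally
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, fp16=False)
print(cosyvoice.list_avaliable_spks())
# each yielded dict carries a (1, num_samples) 'tts_speech' waveform tensor
for i, out in enumerate(cosyvoice.inference_sft('Hello there.', '中文女', stream=False)):
    torchaudio.save('sft_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)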
cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,217 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from functools import partial
+ import json
+ import onnxruntime
+ import torch
+ import numpy as np
+ import whisper
+ from typing import Callable
+ import torchaudio.compliance.kaldi as kaldi
+ import torchaudio
+ import os
+ import re
+ import inflect
+ try:
+     import ttsfrd
+     use_ttsfrd = True
+ except ImportError:
+     print("failed to import ttsfrd, using WeTextProcessing instead")
+     from tn.chinese.normalizer import Normalizer as ZhNormalizer
+     from tn.english.normalizer import Normalizer as EnNormalizer
+     use_ttsfrd = False
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+
+
+ class CosyVoiceFrontEnd:
+
+     def __init__(self,
+                  get_tokenizer: Callable,
+                  feat_extractor: Callable,
+                  campplus_model: str,
+                  speech_tokenizer_model: str,
+                  spk2info: str = '',
+                  instruct: bool = False,
+                  allowed_special: str = 'all'):
+         self.tokenizer = get_tokenizer()
+         self.feat_extractor = feat_extractor
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         option = onnxruntime.SessionOptions()
+         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         option.intra_op_num_threads = 1
+         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+                                                                      providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                                 "CPUExecutionProvider"])
+         if os.path.exists(spk2info):
+             self.spk2info = torch.load(spk2info, map_location=self.device)
+         else:
+             self.spk2info = {}
+         self.instruct = instruct
+         self.allowed_special = allowed_special
+         self.inflect_parser = inflect.engine()
+         self.use_ttsfrd = use_ttsfrd
+         if self.use_ttsfrd:
+             self.frd = ttsfrd.TtsFrontendEngine()
+             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+             assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+                 'failed to initialize ttsfrd resource'
+             self.frd.set_lang_type('pinyinvg')
+         else:
+             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
+             self.en_tn_model = EnNormalizer()
+
+     def _extract_text_token(self, text):
+         text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+         text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+         text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+         return text_token, text_token_len
+
+     def _extract_speech_token(self, speech):
+         assert speech.shape[1] / 16000 <= 30, 'extracting speech tokens from audio longer than 30s is not supported'
+         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+         speech_token = self.speech_tokenizer_session.run(None,
+                                                          {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                           feat.detach().cpu().numpy(),
+                                                           self.speech_tokenizer_session.get_inputs()[1].name:
+                                                           np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+         return speech_token, speech_token_len
+
+     def _extract_spk_embedding(self, speech):
+         feat = kaldi.fbank(speech,
+                            num_mel_bins=80,
+                            dither=0,
+                            sample_frequency=16000)
+         feat = feat - feat.mean(dim=0, keepdim=True)
+         embedding = self.campplus_session.run(None,
+                                               {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+         embedding = torch.tensor([embedding]).to(self.device)
+         return embedding
+
+     def _extract_speech_feat(self, speech):
+         speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+         speech_feat = speech_feat.unsqueeze(dim=0)
+         speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+         return speech_feat, speech_feat_len
+
+     def text_normalize(self, text, split=True):
+         text = text.strip()
+         # NOTE(lyuxiang.lx) move this judgement into ttsfrd in the future
+         for token in self.tokenizer.special_tokens['additional_special_tokens']:
+             if token in text:
+                 return text if split is False else [text]
+         if contains_chinese(text):
+             if self.use_ttsfrd:
+                 texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+                 text = ''.join(texts)
+             else:
+                 text = self.zh_tn_model.normalize(text)
+                 text = text.replace("\n", "")
+                 text = replace_blank(text)
+                 text = replace_corner_mark(text)
+                 text = text.replace(".", "。")
+                 text = text.replace(" - ", ",")
+                 text = remove_bracket(text)
+                 text = re.sub(r'[,,、]+$', '。', text)
+             texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                          token_min_n=60, merge_len=20, comma_split=False))
+         else:
+             if self.use_ttsfrd:
+                 texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+                 text = ''.join(texts)
+             else:
+                 text = self.en_tn_model.normalize(text)
+                 text = spell_out_number(text, self.inflect_parser)
+             texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                          token_min_n=60, merge_len=20, comma_split=False))
+         if split is False:
+             return text
+         return texts
+
+     def frontend_sft(self, tts_text, spk_id):
+         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+         embedding = self.spk2info[spk_id]['embedding']
+         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+         return model_input
+
+     def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
+         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+         prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+         speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+         if resample_rate == 24000:
+             # cosyvoice2: force len(speech_feat) to be exactly 2 * len(speech_token)
+             token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+             speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+             speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+         embedding = self._extract_spk_embedding(prompt_speech_16k)
+         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                        'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                        'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                        'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                        'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                        'llm_embedding': embedding, 'flow_embedding': embedding}
+         return model_input
+
+     def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
+         model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
+         # in cross lingual mode, we remove the prompt in the llm
+         del model_input['prompt_text']
+         del model_input['prompt_text_len']
+         del model_input['llm_prompt_speech_token']
+         del model_input['llm_prompt_speech_token_len']
+         return model_input
+
+     def frontend_instruct(self, tts_text, spk_id, instruct_text):
+         model_input = self.frontend_sft(tts_text, spk_id)
+         # in instruct mode, we remove spk_embedding in the llm due to information leakage
+         del model_input['llm_embedding']
+         instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+         model_input['prompt_text'] = instruct_text_token
+         model_input['prompt_text_len'] = instruct_text_token_len
+         return model_input
+
+     def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
+         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+         prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
+         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+         speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+         if resample_rate == 24000:
+             # cosyvoice2: force len(speech_feat) to be exactly 2 * len(speech_token)
+             token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+             speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+             speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+         embedding = self._extract_spk_embedding(prompt_speech_16k)
+         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                        'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                        'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                        'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                        'llm_embedding': embedding, 'flow_embedding': embedding}
+         return model_input
+
+     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
+         prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
+         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+         prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+         embedding = self._extract_spk_embedding(prompt_speech_16k)
+         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
+         model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
+                        'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
+                        'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
+                        'flow_embedding': embedding}
+         return model_input
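
Note: every frontend_* method above expects prompt audio as a 16 kHz mono tensor no longer than 30 s (see the assertion in _extract_speech_token). A small helper sketch for preparing such a prompt (the wav path is a placeholder):

import torchaudio

def load_prompt_16k(path):
    # load an arbitrary wav, downmix to mono, and resample to the 16 kHz the frontend expects
    speech, sr = torchaudio.load(path)
    speech = speech.mean(dim=0, keepdim=True)
    if sr != 16000:
        speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech)
    return speech

# prompt_speech_16k = load_prompt_16k('prompt.wav')  # placeholder path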
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,421 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ import numpy as np
+ import threading
+ import time
+ from torch.nn import functional as F
+ from contextlib import nullcontext
+ import uuid
+ from cosyvoice.utils.common import fade_in_out
+
+
+ class CosyVoiceModel:
+
+     def __init__(self,
+                  llm: torch.nn.Module,
+                  flow: torch.nn.Module,
+                  hift: torch.nn.Module,
+                  fp16: bool):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.llm = llm
+         self.flow = flow
+         self.hift = hift
+         self.fp16 = fp16
+         self.token_min_hop_len = 2 * self.flow.input_frame_rate
+         self.token_max_hop_len = 4 * self.flow.input_frame_rate
+         self.token_overlap_len = 20
+         # mel fade in out
+         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
+         self.mel_window = np.hamming(2 * self.mel_overlap_len)
+         # hift cache
+         self.mel_cache_len = 20
+         self.source_cache_len = int(self.mel_cache_len * 256)
+         # speech fade in out
+         self.speech_window = np.hamming(2 * self.source_cache_len)
+         # rtf and decoding related
+         self.stream_scale_factor = 1
+         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
+         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+         self.lock = threading.Lock()
+         # dicts used to store session related variables
+         self.tts_speech_token_dict = {}
+         self.llm_end_dict = {}
+         self.mel_overlap_dict = {}
+         self.flow_cache_dict = {}
+         self.hift_cache_dict = {}
+
+     def load(self, llm_model, flow_model, hift_model):
+         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+         self.llm.to(self.device).eval()
+         if self.fp16 is True:
+             self.llm.half()
+         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+         self.flow.to(self.device).eval()
+         # in case hift_model is a hifigan model
+         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+         self.hift.load_state_dict(hift_state_dict, strict=True)
+         self.hift.to(self.device).eval()
+
+     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+         assert self.fp16 is True, "we only provide fp16 jit models, set fp16=True if you want to use a jit model"
+         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
+         self.llm.text_encoder = llm_text_encoder
+         llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
+         self.llm.llm = llm_llm
+         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+         self.flow.encoder = flow_encoder
+
+     def load_onnx(self, flow_decoder_estimator_model):
+         import onnxruntime
+         option = onnxruntime.SessionOptions()
+         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         option.intra_op_num_threads = 1
+         providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+         del self.flow.decoder.estimator
+         self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+         if self.fp16 is True:
+             llm_embedding = llm_embedding.half()
+         with self.llm_context:
+             for i in self.llm.inference(text=text.to(self.device),
+                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_text=prompt_text.to(self.device),
+                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=llm_embedding.to(self.device)):
+                 self.tts_speech_token_dict[uuid].append(i)
+         self.llm_end_dict[uuid] = True
+
+     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+         tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
+                                                   token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                   prompt_token=prompt_token.to(self.device),
+                                                   prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                   prompt_feat=prompt_feat.to(self.device),
+                                                   prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                   embedding=embedding.to(self.device),
+                                                   flow_cache=self.flow_cache_dict[uuid])
+         self.flow_cache_dict[uuid] = flow_cache
+
+         # mel overlap fade in out
+         if self.mel_overlap_dict[uuid].shape[2] != 0:
+             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+         # append hift cache
+         if self.hift_cache_dict[uuid] is not None:
+             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+         else:
+             hift_cache_source = torch.zeros(1, 1, 0)
+         # keep overlap mel and hift cache
+         if finalize is False:
+             self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+             tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+             self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                           'source': tts_source[:, :, -self.source_cache_len:],
+                                           'speech': tts_speech[:, -self.source_cache_len:]}
+             tts_speech = tts_speech[:, :-self.source_cache_len]
+         else:
+             if speed != 1.0:
+                 assert self.hift_cache_dict[uuid] is None, 'speed change is only supported in non-stream inference mode'
+                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+         return tts_speech
+
+     def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+             prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+         # this_uuid is used to track variables related to this inference thread
+         this_uuid = str(uuid.uuid1())
+         with self.lock:
+             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+             self.hift_cache_dict[this_uuid] = None
+             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+         p.start()
+         if stream is True:
+             token_hop_len = self.token_min_hop_len
+             while True:
+                 time.sleep(0.1)
+                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                     this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                         .unsqueeze(dim=0)
+                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                      prompt_token=flow_prompt_speech_token,
+                                                      prompt_feat=prompt_speech_feat,
+                                                      embedding=flow_embedding,
+                                                      uuid=this_uuid,
+                                                      finalize=False)
+                     yield {'tts_speech': this_tts_speech.cpu()}
+                     with self.lock:
+                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                     # increase token_hop_len for better speech quality
+                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                     break
+             p.join()
+             # deal with remaining tokens; make sure the remaining token len equals token_hop_len when cache_speech is not None
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              finalize=True)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         else:
+             # deal with all tokens
+             p.join()
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              finalize=True,
+                                              speed=speed)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         with self.lock:
+             self.tts_speech_token_dict.pop(this_uuid)
+             self.llm_end_dict.pop(this_uuid)
+             self.mel_overlap_dict.pop(this_uuid)
+             self.hift_cache_dict.pop(this_uuid)
+             self.flow_cache_dict.pop(this_uuid)
+
+     def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
+         # this_uuid is used to track variables related to this inference thread
+         this_uuid = str(uuid.uuid1())
+         with self.lock:
+             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
+             self.hift_cache_dict[this_uuid] = None
+             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+         if stream is True:
+             token_hop_len = self.token_min_hop_len
+             while True:
+                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                     this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                         .unsqueeze(dim=0)
+                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                      prompt_token=flow_prompt_speech_token,
+                                                      prompt_feat=prompt_speech_feat,
+                                                      embedding=flow_embedding,
+                                                      uuid=this_uuid,
+                                                      finalize=False)
+                     yield {'tts_speech': this_tts_speech.cpu()}
+                     with self.lock:
+                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                     # increase token_hop_len for better speech quality
+                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                     break
+             # deal with remaining tokens; make sure the remaining token len equals token_hop_len when cache_speech is not None
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              finalize=True)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         else:
+             # deal with all tokens
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              finalize=True,
+                                              speed=speed)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         with self.lock:
+             self.tts_speech_token_dict.pop(this_uuid)
+             self.llm_end_dict.pop(this_uuid)
+             self.mel_overlap_dict.pop(this_uuid)
+             self.hift_cache_dict.pop(this_uuid)
+             self.flow_cache_dict.pop(this_uuid)  # bugfix: release the flow cache for this session as well
+
+
+ class CosyVoice2Model:
+
+     def __init__(self,
+                  llm: torch.nn.Module,
+                  flow: torch.nn.Module,
+                  hift: torch.nn.Module):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.llm = llm
+         self.flow = flow
+         self.hift = hift
+         self.token_hop_len = 2 * self.flow.input_frame_rate
+         # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument, or use a cache
+         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
+         self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
+         # hift cache
+         self.mel_cache_len = 8
+         self.source_cache_len = int(self.mel_cache_len * 480)
+         # speech fade in out
+         self.speech_window = np.hamming(2 * self.source_cache_len)
+         # rtf and decoding related
+         self.stream_scale_factor = 1
+         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+         self.lock = threading.Lock()
+         # dicts used to store session related variables
+         self.tts_speech_token_dict = {}
+         self.llm_end_dict = {}
+         self.hift_cache_dict = {}
+
+     def load(self, llm_model, flow_model, hift_model):
+         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+         self.llm.to(self.device).eval()
+         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+         self.flow.to(self.device).eval()
+         self.flow.decoder.fp16 = False
+         # in case hift_model is a hifigan model
+         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+         self.hift.load_state_dict(hift_state_dict, strict=True)
+         self.hift.to(self.device).eval()
+
+     def load_jit(self, flow_encoder_model):
+         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+         self.flow.encoder = flow_encoder
+
+     def load_onnx(self, flow_decoder_estimator_model):
+         import onnxruntime
+         option = onnxruntime.SessionOptions()
+         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+         option.intra_op_num_threads = 1
+         providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+         del self.flow.decoder.estimator
+         self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+     def load_trt(self, flow_decoder_estimator_model):
+         del self.flow.decoder.estimator
+         import tensorrt as trt
+         with open(flow_decoder_estimator_model, 'rb') as f:
+             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+         self.flow.decoder.fp16 = True
+
+     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+         with self.llm_context:
+             for i in self.llm.inference(text=text.to(self.device),
+                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_text=prompt_text.to(self.device),
+                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=llm_embedding.to(self.device)):
+                 self.tts_speech_token_dict[uuid].append(i)
+         self.llm_end_dict[uuid] = True
+
+     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
+         tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_token=prompt_token.to(self.device),
+                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_feat=prompt_feat.to(self.device),
+                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                          embedding=embedding.to(self.device),
+                                          finalize=finalize)
+         tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+         # append hift cache
+         if self.hift_cache_dict[uuid] is not None:
+             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+         else:
+             hift_cache_source = torch.zeros(1, 1, 0)
+         # keep overlap mel and hift cache
+         if finalize is False:
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+             self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                           'source': tts_source[:, :, -self.source_cache_len:],
+                                           'speech': tts_speech[:, -self.source_cache_len:]}
+             tts_speech = tts_speech[:, :-self.source_cache_len]
+         else:
+             if speed != 1.0:
+                 assert self.hift_cache_dict[uuid] is None, 'speed change is only supported in non-stream inference mode'
+                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+             tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+             if self.hift_cache_dict[uuid] is not None:
+                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+         return tts_speech
+
+     def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+             prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+             llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+             prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+         # this_uuid is used to track variables related to this inference thread
+         this_uuid = str(uuid.uuid1())
+         with self.lock:
+             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+             self.hift_cache_dict[this_uuid] = None
+         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+         p.start()
+         if stream is True:
+             token_offset = 0
+             while True:
+                 time.sleep(0.1)
+                 if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
+                     this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                      prompt_token=flow_prompt_speech_token,
+                                                      prompt_feat=prompt_speech_feat,
+                                                      embedding=flow_embedding,
+                                                      uuid=this_uuid,
+                                                      token_offset=token_offset,
+                                                      finalize=False)
+                     token_offset += self.token_hop_len
+                     yield {'tts_speech': this_tts_speech.cpu()}
+                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
+                     break
+             p.join()
+             # deal with remaining tokens; make sure the remaining token len equals token_hop_len when cache_speech is not None
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              token_offset=token_offset,
+                                              finalize=True)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         else:
+             # deal with all tokens
+             p.join()
+             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+             this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                              prompt_token=flow_prompt_speech_token,
+                                              prompt_feat=prompt_speech_feat,
+                                              embedding=flow_embedding,
+                                              uuid=this_uuid,
+                                              token_offset=0,
+                                              finalize=True,
+                                              speed=speed)
+             yield {'tts_speech': this_tts_speech.cpu()}
+         with self.lock:
+             self.tts_speech_token_dict.pop(this_uuid)
+             self.llm_end_dict.pop(this_uuid)
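
Note: tts() above runs the LLM in a background thread and, in stream mode, yields audio as soon as enough speech tokens have accumulated; chunk boundaries are already cross-faded inside token2wav, so a consumer only has to concatenate. A minimal sketch (cosyvoice is assumed to be a wrapper instance from cosyvoice.py above):

import torch

def collect_stream(generator):
    # gather the (1, n) waveform chunks yielded by tts()/vc() into one tensor
    chunks = [out['tts_speech'] for out in generator]
    return torch.concat(chunks, dim=1)

# speech = collect_stream(cosyvoice.inference_sft('A long sentence...', '中文女', stream=True))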
cosyvoice/dataset/__init__.py ADDED
File without changes
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,164 @@
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
+ #               2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import random
+ import json
+ import math
+ from functools import partial
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import IterableDataset
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
+
+
+ class Processor(IterableDataset):
+
+     def __init__(self, source, f, *args, **kw):
+         assert callable(f)
+         self.source = source
+         self.f = f
+         self.args = args
+         self.kw = kw
+
+     def set_epoch(self, epoch):
+         self.source.set_epoch(epoch)
+
+     def __iter__(self):
+         """ Return an iterator over the source dataset processed by the
+             given processor.
+         """
+         assert self.source is not None
+         assert callable(self.f)
+         return self.f(iter(self.source), *self.args, **self.kw)
+
+     def apply(self, f):
+         assert callable(f)
+         return Processor(self, f, *self.args, **self.kw)
+
+
+ class DistributedSampler:
+
+     def __init__(self, shuffle=True, partition=True):
+         self.epoch = -1
+         self.update()
+         self.shuffle = shuffle
+         self.partition = partition
+
+     def update(self):
+         assert dist.is_available()
+         if dist.is_initialized():
+             self.rank = dist.get_rank()
+             self.world_size = dist.get_world_size()
+         else:
+             self.rank = 0
+             self.world_size = 1
+         worker_info = torch.utils.data.get_worker_info()
+         if worker_info is None:
+             self.worker_id = 0
+             self.num_workers = 1
+         else:
+             self.worker_id = worker_info.id
+             self.num_workers = worker_info.num_workers
+         return dict(rank=self.rank,
+                     world_size=self.world_size,
+                     worker_id=self.worker_id,
+                     num_workers=self.num_workers)
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+     def sample(self, data):
+         """ Sample data according to rank/world_size/num_workers
+
+         Args:
+             data(List): input data list
+
+         Returns:
+             List: data list after sampling
+         """
+         data = list(range(len(data)))
+         # force datalist even
+         if self.partition:
+             if self.shuffle:
+                 random.Random(self.epoch).shuffle(data)
+             if len(data) < self.world_size:
+                 data = data * math.ceil(self.world_size / len(data))
+                 data = data[:self.world_size]
+             data = data[self.rank::self.world_size]
+         if len(data) < self.num_workers:
+             data = data * math.ceil(self.num_workers / len(data))
+             data = data[:self.num_workers]
+         data = data[self.worker_id::self.num_workers]
+         return data
+
+
+ class DataList(IterableDataset):
+
+     def __init__(self, lists, shuffle=True, partition=True):
+         self.lists = lists
+         self.sampler = DistributedSampler(shuffle, partition)
+
+     def set_epoch(self, epoch):
+         self.sampler.set_epoch(epoch)
+
+     def __iter__(self):
+         sampler_info = self.sampler.update()
+         indexes = self.sampler.sample(self.lists)
+         for index in indexes:
+             data = dict(src=self.lists[index])
+             data.update(sampler_info)
+             yield data
+
+
+ def Dataset(data_list_file,
+             data_pipeline,
+             mode='train',
+             gan=False,
+             shuffle=True,
+             partition=True,
+             tts_file='',
+             prompt_utt2data=''):
+     """ Construct dataset from arguments
+
+     We have two shuffle stages in the Dataset. The first is a global
+     shuffle at the shard tar/raw file level. The second is a shuffle
+     at the training-sample level.
+
+     Args:
+         data_list_file(str): file listing the data shards
+         data_pipeline(list): processor functions applied in order
+         mode(str): 'train' or 'inference'
+         partition(bool): whether to do data partition in terms of rank
+     """
+     assert mode in ['train', 'inference']
+     lists = read_lists(data_list_file)
+     if mode == 'inference':
+         with open(tts_file) as f:
+             tts_data = json.load(f)
+         utt2lists = read_json_lists(prompt_utt2data)
+         # filter unnecessary files in inference mode
+         lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
+     dataset = DataList(lists,
+                        shuffle=shuffle,
+                        partition=partition)
+     if mode == 'inference':
+         # map partial arg to parquet_opener func in inference mode
+         data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+     if gan is True:
+         # map partial arg to padding func in gan mode
+         data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
+     for func in data_pipeline:
+         dataset = Processor(dataset, func, mode=mode)
+     return dataset
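
Note: data_pipeline is just an ordered list of generator functions such as those defined in processor.py below; Dataset() wraps each one around the previous iterator via Processor. A schematic sketch (the partial arguments are illustrative, not the project's actual training configuration, which also chains tokenization, fbank extraction, shuffling, sorting, batching and padding):

from functools import partial
from cosyvoice.dataset.dataset import Dataset
from cosyvoice.dataset.processor import parquet_opener, filter, resample

data_pipeline = [
    parquet_opener,                                      # read samples out of parquet shards
    partial(filter, max_length=40960, min_length=100),   # drop out-of-range utterances
    partial(resample, resample_rate=22050),              # unify the sample rate
]
# dataset = Dataset('data.list', data_pipeline, mode='train', gan=False)  # placeholder list file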
cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,431 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import random
+
+ import pyarrow.parquet as pq
+ from io import BytesIO
+ import torch
+ import torchaudio
+ from torch.nn.utils.rnn import pad_sequence
+ import torch.nn.functional as F
+
+ torchaudio.set_audio_backend('soundfile')
+
+ AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
+
+
+ def parquet_opener(data, mode='train', tts_data={}):
+     """ Given a url or local file, return a file descriptor.
+     Inplace operation.
+
+     Args:
+         data(Iterable[str]): url or local file list
+
+     Returns:
+         Iterable[{src, stream}]
+     """
+     for sample in data:
+         assert 'src' in sample
+         url = sample['src']
+         try:
+             for df in pq.ParquetFile(url).iter_batches(batch_size=64):
+                 df = df.to_pandas()
+                 for i in range(len(df)):
+                     if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
+                         continue
+                     sample.update(dict(df.loc[i]))
+                     if mode == 'train':
+                         # NOTE do not return sample directly, must initialize a new dict
+                         yield {**sample}
+                     else:
+                         for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
+                             yield {**sample, 'tts_index': index, 'tts_text': text}
+         except Exception as ex:
+             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
+
+
+ def filter(data,
+            max_length=10240,
+            min_length=10,
+            token_max_length=200,
+            token_min_length=1,
+            min_output_input_ratio=0.0005,
+            max_output_input_ratio=1,
+            mode='train'):
+     """ Filter samples according to feature and label length.
+     Inplace operation.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         max_length: drop utterances longer than max_length (in 10ms frames)
+         min_length: drop utterances shorter than min_length (in 10ms frames)
+         token_max_length: drop utterances with more than token_max_length
+             tokens, especially when char units are used for English modeling
+         token_min_length: drop utterances with fewer than token_min_length tokens
+         min_output_input_ratio: minimal ratio of
+             token_length / feats_length (in 10ms frames)
+         max_output_input_ratio: maximum ratio of
+             token_length / feats_length (in 10ms frames)
+
+     Returns:
+         Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+         sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
+         del sample['audio_data']
+         # sample['speech'] is a torch.Tensor; we count 100 frames per second
+         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
+         if num_frames < min_length:
+             continue
+         if num_frames > max_length:
+             continue
+         if len(sample['text_token']) < token_min_length:
+             continue
+         if len(sample['text_token']) > token_max_length:
+             continue
+         if len(sample['speech_token']) == 0:
+             continue
+         if num_frames != 0:
+             if len(sample['text_token']) / num_frames < min_output_input_ratio:
+                 continue
+             if len(sample['text_token']) / num_frames > max_output_input_ratio:
+                 continue
+         yield sample
+
+
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
+     """ Resample data.
+     Inplace operation.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         resample_rate: target resample rate
+
+     Returns:
+         Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         sample_rate = sample['sample_rate']
+         waveform = sample['speech']
+         if sample_rate != resample_rate:
+             if sample_rate < min_sample_rate:
+                 continue
+             sample['sample_rate'] = resample_rate
+             sample['speech'] = torchaudio.transforms.Resample(
+                 orig_freq=sample_rate, new_freq=resample_rate)(waveform)
+         max_val = sample['speech'].abs().max()
+         if max_val > 1:
+             sample['speech'] /= max_val
+         yield sample
+
+
+ def truncate(data, truncate_length=24576, mode='train'):
+     """ Truncate data.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+         truncate_length: truncate length
+
+     Returns:
+         Iterable[{key, wav, label, sample_rate}]
+     """
+     for sample in data:
+         waveform = sample['speech']
+         if waveform.shape[1] > truncate_length:
+             start = random.randint(0, waveform.shape[1] - truncate_length)
+             waveform = waveform[:, start: start + truncate_length]
+         else:
+             waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
+         sample['speech'] = waveform
+         yield sample
+
+
+ def compute_fbank(data,
+                   feat_extractor,
+                   mode='train'):
+     """ Extract fbank features.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         assert 'utt' in sample
+         assert 'text_token' in sample
+         waveform = sample['speech']
+         mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+         sample['speech_feat'] = mat
+         yield sample
+
+
+ def compute_f0(data, pitch_extractor, mode='train'):
+     """ Extract f0.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         assert 'sample_rate' in sample
+         assert 'speech' in sample
+         assert 'utt' in sample
+         assert 'text_token' in sample
+         waveform = sample['speech']
+         mat = pitch_extractor(waveform).transpose(1, 2)
+         mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
+         sample['pitch_feat'] = mat[0, 0]
+         yield sample
+
+
+ def parse_embedding(data, normalize, mode='train'):
+     """ Parse utt_embedding/spk_embedding.
+
+     Args:
+         data: Iterable[{key, wav, label, sample_rate}]
+
+     Returns:
+         Iterable[{key, feat, label}]
+     """
+     for sample in data:
+         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
+         sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
+         if normalize:
+             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
+             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
+         yield sample
+
+
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
222
+ """ Decode text to chars or BPE
223
+ Inplace operation
224
+
225
+ Args:
226
+ data: Iterable[{key, wav, txt, sample_rate}]
227
+
228
+ Returns:
229
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
230
+ """
231
+ tokenizer = get_tokenizer()
232
+ for sample in data:
233
+ assert 'text' in sample
234
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
235
+ if mode == 'inference':
236
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
237
+ yield sample
238
+
239
+
240
+ def shuffle(data, shuffle_size=10000, mode='train'):
241
+ """ Local shuffle the data
242
+
243
+ Args:
244
+ data: Iterable[{key, feat, label}]
245
+ shuffle_size: buffer size for shuffle
246
+
247
+ Returns:
248
+ Iterable[{key, feat, label}]
249
+ """
250
+ buf = []
251
+ for sample in data:
252
+ buf.append(sample)
253
+ if len(buf) >= shuffle_size:
254
+ random.shuffle(buf)
255
+ for x in buf:
256
+ yield x
257
+ buf = []
258
+ # The sample left over
259
+ random.shuffle(buf)
260
+ for x in buf:
261
+ yield x
262
+
263
+
264
+ def sort(data, sort_size=500, mode='train'):
265
+ """ Sort the data by feature length.
266
+ Sort is used after shuffle and before batch, so we can group
267
+ utts with similar lengths into a batch, and `sort_size` should
268
+ be less than `shuffle_size`
269
+
270
+ Args:
271
+ data: Iterable[{key, feat, label}]
272
+ sort_size: buffer size for sort
273
+
274
+ Returns:
275
+ Iterable[{key, feat, label}]
276
+ """
277
+
278
+ buf = []
279
+ for sample in data:
280
+ buf.append(sample)
281
+ if len(buf) >= sort_size:
282
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
283
+ for x in buf:
284
+ yield x
285
+ buf = []
286
+ # The sample left over
287
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
288
+ for x in buf:
289
+ yield x
290
+
291
+
292
+ def static_batch(data, batch_size=16):
293
+ """ Static batch the data by `batch_size`
294
+
295
+ Args:
296
+ data: Iterable[{key, feat, label}]
297
+ batch_size: batch size
298
+
299
+ Returns:
300
+ Iterable[List[{key, feat, label}]]
301
+ """
302
+ buf = []
303
+ for sample in data:
304
+ buf.append(sample)
305
+ if len(buf) >= batch_size:
306
+ yield buf
307
+ buf = []
308
+ if len(buf) > 0:
309
+ yield buf
310
+
311
+
312
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
313
+ """ Dynamic batch the data until the total frames in batch
314
+ reach `max_frames_in_batch`
315
+
316
+ Args:
317
+ data: Iterable[{key, feat, label}]
318
+ max_frames_in_batch: max_frames in one batch
319
+
320
+ Returns:
321
+ Iterable[List[{key, feat, label}]]
322
+ """
323
+ buf = []
324
+ longest_frames = 0
325
+ for sample in data:
326
+ assert 'speech_feat' in sample
327
+ assert isinstance(sample['speech_feat'], torch.Tensor)
328
+ new_sample_frames = sample['speech_feat'].size(0)
329
+ longest_frames = max(longest_frames, new_sample_frames)
330
+ frames_after_padding = longest_frames * (len(buf) + 1)
331
+ if frames_after_padding > max_frames_in_batch:
332
+ yield buf
333
+ buf = [sample]
334
+ longest_frames = new_sample_frames
335
+ else:
336
+ buf.append(sample)
337
+ if len(buf) > 0:
338
+ yield buf
339
+
340
+
341
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
342
+ """ Wrapper for static/dynamic batch
343
+ """
344
+ if mode == 'inference':
345
+ return static_batch(data, 1)
346
+ else:
347
+ if batch_type == 'static':
348
+ return static_batch(data, batch_size)
349
+ elif batch_type == 'dynamic':
350
+ return dynamic_batch(data, max_frames_in_batch)
351
+ else:
352
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
353
+
354
+
355
+ def padding(data, use_spk_embedding, mode='train', gan=False):
356
+ """ Padding the data into training data
357
+
358
+ Args:
359
+ data: Iterable[List[{key, feat, label}]]
360
+
361
+ Returns:
362
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
363
+ """
364
+ for sample in data:
365
+ assert isinstance(sample, list)
366
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
367
+ dtype=torch.int32)
368
+ order = torch.argsort(speech_feat_len, descending=True)
369
+
370
+ utts = [sample[i]['utt'] for i in order]
371
+ speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
372
+ speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
373
+ speech = pad_sequence(speech, batch_first=True, padding_value=0)
374
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
375
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
376
+ speech_token = pad_sequence(speech_token,
377
+ batch_first=True,
378
+ padding_value=0)
379
+ speech_feat = [sample[i]['speech_feat'] for i in order]
380
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
381
+ speech_feat = pad_sequence(speech_feat,
382
+ batch_first=True,
383
+ padding_value=0)
384
+ text = [sample[i]['text'] for i in order]
385
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
386
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
387
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
388
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
389
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
390
+ batch = {
391
+ "utts": utts,
392
+ "speech": speech,
393
+ "speech_len": speech_len,
394
+ "speech_token": speech_token,
395
+ "speech_token_len": speech_token_len,
396
+ "speech_feat": speech_feat,
397
+ "speech_feat_len": speech_feat_len,
398
+ "text": text,
399
+ "text_token": text_token,
400
+ "text_token_len": text_token_len,
401
+ "utt_embedding": utt_embedding,
402
+ "spk_embedding": spk_embedding,
403
+ }
404
+ if gan is True:
405
+ # in gan train, we need pitch_feat
406
+ pitch_feat = [sample[i]['pitch_feat'] for i in order]
407
+ pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
408
+ pitch_feat = pad_sequence(pitch_feat,
409
+ batch_first=True,
410
+ padding_value=0)
411
+ batch["pitch_feat"] = pitch_feat
412
+ batch["pitch_feat_len"] = pitch_feat_len
413
+ else:
414
+ # only gan train needs speech, delete it to save memory
415
+ del batch["speech"]
416
+ del batch["speech_len"]
417
+ if mode == 'inference':
418
+ tts_text = [sample[i]['tts_text'] for i in order]
419
+ tts_index = [sample[i]['tts_index'] for i in order]
420
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
421
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
422
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
423
+ batch.update({'tts_text': tts_text,
424
+ 'tts_index': tts_index,
425
+ 'tts_text_token': tts_text_token,
426
+ 'tts_text_token_len': tts_text_token_len})
427
+ if use_spk_embedding is True:
428
+ batch["embedding"] = batch["spk_embedding"]
429
+ else:
430
+ batch["embedding"] = batch["utt_embedding"]
431
+ yield batch
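Each function above is a lazy generator transform over an iterable of sample dicts, so a full pipeline is built by nesting the stages. A minimal sketch of that composition; `shards`, `tokenizer_factory`, and `fbank` are hypothetical stand-ins for the objects the real dataset config wires in, and the stage order mirrors the tokenize -> filter -> resample -> fbank -> embedding -> shuffle -> sort -> batch -> padding flow implied by the field names each stage consumes:

def build_pipeline(shards, tokenizer_factory, fbank):
    data = ({'src': s} for s in shards)      # Iterable[{src}]
    data = parquet_opener(data)              # read rows from parquet shards
    data = tokenize(data, tokenizer_factory, allowed_special='all')
    data = filter(data)                      # loads audio, drops bad lengths
    data = resample(data, resample_rate=22050)
    data = compute_fbank(data, fbank)
    data = parse_embedding(data, normalize=True)
    data = shuffle(data, shuffle_size=1000)
    data = sort(data, sort_size=500)         # group similar lengths
    data = batch(data, batch_type='dynamic', max_frames_in_batch=2000)
    data = padding(data, use_spk_embedding=False)
    return data                              # yields padded batch dicts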
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,301 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import pack, rearrange, repeat
+ from cosyvoice.utils.common import mask_to_bias
+ from cosyvoice.utils.mask import add_optional_chunk_mask
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
+ from matcha.models.components.transformer import BasicTransformerBlock
+
+
+ class Transpose(torch.nn.Module):
+     def __init__(self, dim0: int, dim1: int):
+         super().__init__()
+         self.dim0 = dim0
+         self.dim1 = dim1
+
+     def forward(self, x: torch.Tensor):
+         x = torch.transpose(x, self.dim0, self.dim1)
+         return x
+
+
+ class CausalBlock1D(Block1D):
+     def __init__(self, dim: int, dim_out: int):
+         super(CausalBlock1D, self).__init__(dim, dim_out)
+         self.block = torch.nn.Sequential(
+             CausalConv1d(dim, dim_out, 3),
+             Transpose(1, 2),
+             nn.LayerNorm(dim_out),
+             Transpose(1, 2),
+             nn.Mish(),
+         )
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor):
+         output = self.block(x * mask)
+         return output * mask
+
+
+ class CausalResnetBlock1D(ResnetBlock1D):
+     def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+         super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+         self.block1 = CausalBlock1D(dim, dim_out)
+         self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+ class CausalConv1d(torch.nn.Conv1d):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         padding_mode: str = 'zeros',
+         device=None,
+         dtype=None
+     ) -> None:
+         super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                            kernel_size, stride,
+                                            padding=0, dilation=dilation,
+                                            groups=groups, bias=bias,
+                                            padding_mode=padding_mode,
+                                            device=device, dtype=dtype)
+         assert stride == 1
+         self.causal_padding = (kernel_size - 1, 0)
+
+     def forward(self, x: torch.Tensor):
+         x = F.pad(x, self.causal_padding)
+         x = super(CausalConv1d, self).forward(x)
+         return x
+
+
+ class ConditionalDecoder(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         causal=False,
+         channels=(256, 256),
+         dropout=0.05,
+         attention_head_dim=64,
+         n_blocks=1,
+         num_mid_blocks=2,
+         num_heads=4,
+         act_fn="snake",
+     ):
+         """
+         This decoder requires an input with the same shape as the target. So, if your text content
+         is shorter or longer than the output, please resample it before feeding it to the decoder.
+         """
+         super().__init__()
+         channels = tuple(channels)
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.causal = causal
+         # NOTE: forward() reads self.static_chunk_size when building attention
+         # masks; default to 0 (full-context attention) so the attribute always
+         # exists. Streaming setups are expected to overwrite it.
+         self.static_chunk_size = 0
+         self.time_embeddings = SinusoidalPosEmb(in_channels)
+         time_embed_dim = channels[0] * 4
+         self.time_mlp = TimestepEmbedding(
+             in_channels=in_channels,
+             time_embed_dim=time_embed_dim,
+             act_fn="silu",
+         )
+         self.down_blocks = nn.ModuleList([])
+         self.mid_blocks = nn.ModuleList([])
+         self.up_blocks = nn.ModuleList([])
+
+         output_channel = in_channels
+         for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+             input_channel = output_channel
+             output_channel = channels[i]
+             is_last = i == len(channels) - 1
+             resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                 ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+             downsample = (
+                 Downsample1D(output_channel) if not is_last else
+                 CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+             )
+             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+         for _ in range(num_mid_blocks):
+             input_channel = channels[-1]
+             out_channels = channels[-1]
+             resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                 ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+
+             self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+         channels = channels[::-1] + (channels[0],)
+         for i in range(len(channels) - 1):
+             input_channel = channels[i] * 2
+             output_channel = channels[i + 1]
+             is_last = i == len(channels) - 2
+             resnet = CausalResnetBlock1D(
+                 dim=input_channel,
+                 dim_out=output_channel,
+                 time_emb_dim=time_embed_dim,
+             ) if self.causal else ResnetBlock1D(
+                 dim=input_channel,
+                 dim_out=output_channel,
+                 time_emb_dim=time_embed_dim,
+             )
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+             upsample = (
+                 Upsample1D(output_channel, use_conv_transpose=True)
+                 if not is_last
+                 else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+             )
+             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+         self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
+         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d):
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.GroupNorm):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, mask, mu, t, spks=None, cond=None):
+         """Forward pass of the UNet1DConditional model.
+
+         Args:
+             x (torch.Tensor): shape (batch_size, in_channels, time)
+             mask (torch.Tensor): shape (batch_size, 1, time)
+             mu (torch.Tensor): encoder output, shape (batch_size, in_channels, time)
+             t (torch.Tensor): shape (batch_size)
+             spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
+             cond (torch.Tensor, optional): extra conditioning features, packed with x when given. Defaults to None.
+
+         Returns:
+             torch.Tensor: shape (batch_size, out_channels, time)
+         """
+
+         t = self.time_embeddings(t).to(t.dtype)
+         t = self.time_mlp(t)
+
+         x = pack([x, mu], "b * t")[0]
+
+         if spks is not None:
+             spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+             x = pack([x, spks], "b * t")[0]
+         if cond is not None:
+             x = pack([x, cond], "b * t")[0]
+
+         hiddens = []
+         masks = [mask]
+         for resnet, transformer_blocks, downsample in self.down_blocks:
+             mask_down = masks[-1]
+             x = resnet(x, mask_down, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+             attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+             attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+             hiddens.append(x)  # Save hidden states for skip connections
+             x = downsample(x * mask_down)
+             masks.append(mask_down[:, :, ::2])
+         masks = masks[:-1]
+         mask_mid = masks[-1]
+
+         for resnet, transformer_blocks in self.mid_blocks:
+             x = resnet(x, mask_mid, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+             attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+             attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+
+         for resnet, transformer_blocks, upsample in self.up_blocks:
+             mask_up = masks.pop()
+             skip = hiddens.pop()
+             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+             x = resnet(x, mask_up, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+             attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+             attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+             x = upsample(x * mask_up)
+         x = self.final_block(x, mask_up)
+         output = self.final_proj(x * mask_up)
+         return output * mask
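A quick property check of CausalConv1d (a sketch, not part of the library): because padding is applied only on the left, the output length matches the input length and frame t never depends on frames after t, so zeroing out the future must leave earlier outputs unchanged.

import torch

conv = CausalConv1d(4, 4, kernel_size=3)
x = torch.randn(1, 4, 10)
y1 = conv(x)
x2 = x.clone()
x2[:, :, 7:] = 0.0                 # perturb only frames 7..9
y2 = conv(x2)
assert y1.shape == x.shape         # left-only padding preserves length
assert torch.allclose(y1[:, :, :7], y2[:, :, :7])  # frames 0..6 unaffected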
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,237 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import random
+ from typing import Dict, Optional
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from omegaconf import DictConfig
+ from cosyvoice.utils.mask import make_pad_mask
+
+
+ class MaskedDiffWithXvec(torch.nn.Module):
+     def __init__(self,
+                  input_size: int = 512,
+                  output_size: int = 80,
+                  spk_embed_dim: int = 192,
+                  output_type: str = "mel",
+                  vocab_size: int = 4096,
+                  input_frame_rate: int = 50,
+                  only_mask_loss: bool = True,
+                  encoder: torch.nn.Module = None,
+                  length_regulator: torch.nn.Module = None,
+                  decoder: torch.nn.Module = None,
+                  decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                        'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                  'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                        'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                           'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                  mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                         'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+         self.decoder_conf = decoder_conf
+         self.mel_feat_conf = mel_feat_conf
+         self.vocab_size = vocab_size
+         self.output_type = output_type
+         self.input_frame_rate = input_frame_rate
+         logging.info(f"input frame rate={self.input_frame_rate}")
+         self.input_embedding = nn.Embedding(vocab_size, input_size)
+         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+         self.encoder = encoder
+         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+         self.decoder = decoder
+         self.length_regulator = length_regulator
+         self.only_mask_loss = only_mask_loss
+
+     def forward(
+             self,
+             batch: dict,
+             device: torch.device,
+     ) -> Dict[str, Optional[torch.Tensor]]:
+         token = batch['speech_token'].to(device)
+         token_len = batch['speech_token_len'].to(device)
+         feat = batch['speech_feat'].to(device)
+         feat_len = batch['speech_feat_len'].to(device)
+         embedding = batch['embedding'].to(device)
+
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # mask and embed speech tokens
+         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+         # Clamp tokens to valid vocabulary range
+         token = torch.clamp(token, min=0, max=self.vocab_size - 1)
+         token = self.input_embedding(token) * mask
+
+         # token encode
+         h, h_lengths = self.encoder(token, token_len)
+         h = self.encoder_proj(h)
+         h, h_lengths = self.length_regulator(h, feat_len)
+
+         # get conditions
+         conds = torch.zeros(feat.shape, device=token.device)
+         for i, j in enumerate(feat_len):
+             if random.random() < 0.5:
+                 continue
+             index = random.randint(0, int(0.3 * j))
+             conds[i, :index] = feat[i, :index]
+         conds = conds.transpose(1, 2)
+
+         mask = (~make_pad_mask(feat_len)).to(h)
+         feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
+         loss, _ = self.decoder.compute_loss(
+             feat.transpose(1, 2).contiguous(),
+             mask.unsqueeze(1),
+             h.transpose(1, 2).contiguous(),
+             embedding,
+             cond=conds
+         )
+         return {'loss': loss}
+
+     @torch.inference_mode()
+     def inference(self,
+                   token,
+                   token_len,
+                   prompt_token,
+                   prompt_token_len,
+                   prompt_feat,
+                   prompt_feat_len,
+                   embedding,
+                   flow_cache):
+         assert token.shape[0] == 1
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # concat prompt and target speech tokens
+         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
+         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+         # Clamp tokens to valid vocabulary range
+         token = torch.clamp(token, min=0, max=self.vocab_size - 1)
+         token = self.input_embedding(token) * mask
+
+         # token encode
+         h, h_lengths = self.encoder(token, token_len)
+         h = self.encoder_proj(h)
+         mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
+         h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
+
+         # get conditions
+         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+         conds[:, :mel_len1] = prompt_feat
+         conds = conds.transpose(1, 2)
+
+         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+         feat, flow_cache = self.decoder(
+             mu=h.transpose(1, 2).contiguous(),
+             mask=mask.unsqueeze(1),
+             spks=embedding,
+             cond=conds,
+             n_timesteps=10,
+             prompt_len=mel_len1,
+             flow_cache=flow_cache
+         )
+         feat = feat[:, :, mel_len1:]
+         assert feat.shape[2] == mel_len2
+         return feat, flow_cache
+
+
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
+     def __init__(self,
+                  input_size: int = 512,
+                  output_size: int = 80,
+                  spk_embed_dim: int = 192,
+                  output_type: str = "mel",
+                  vocab_size: int = 4096,
+                  input_frame_rate: int = 50,
+                  only_mask_loss: bool = True,
+                  token_mel_ratio: int = 2,
+                  pre_lookahead_len: int = 3,
+                  encoder: torch.nn.Module = None,
+                  decoder: torch.nn.Module = None,
+                  decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                        'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                  'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                        'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                           'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                  mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                         'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+         self.decoder_conf = decoder_conf
+         self.mel_feat_conf = mel_feat_conf
+         self.vocab_size = vocab_size
+         self.output_type = output_type
+         self.input_frame_rate = input_frame_rate
+         logging.info(f"input frame rate={self.input_frame_rate}")
+         self.input_embedding = nn.Embedding(vocab_size, input_size)
+         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+         self.encoder = encoder
+         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+         self.decoder = decoder
+         self.only_mask_loss = only_mask_loss
+         self.token_mel_ratio = token_mel_ratio
+         self.pre_lookahead_len = pre_lookahead_len
+
+     @torch.inference_mode()
+     def inference(self,
+                   token,
+                   token_len,
+                   prompt_token,
+                   prompt_token_len,
+                   prompt_feat,
+                   prompt_feat_len,
+                   embedding,
+                   finalize):
+         assert token.shape[0] == 1
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # concat prompt and target speech tokens
+         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+         # Clamp tokens to valid vocabulary range
+         token = torch.clamp(token, min=0, max=self.vocab_size - 1)
+         token = self.input_embedding(token) * mask
+
+         # token encode
+         h, h_lengths = self.encoder(token, token_len)
+         if finalize is False:
+             h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+         h = self.encoder_proj(h)
+
+         # get conditions
+         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+         conds[:, :mel_len1] = prompt_feat
+         conds = conds.transpose(1, 2)
+
+         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+         feat, _ = self.decoder(
+             mu=h.transpose(1, 2).contiguous(),
+             mask=mask.unsqueeze(1),
+             spks=embedding,
+             cond=conds,
+             n_timesteps=10
+         )
+         feat = feat[:, :, mel_len1:]
+         assert feat.shape[2] == mel_len2
+         return feat, None
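A worked example of the token-to-mel length arithmetic in MaskedDiffWithXvec.inference(): speech tokens arrive at input_frame_rate (50 Hz) while mels are produced at sampling_rate / hop_size = 22050 / 256 ≈ 86.13 frames per second, so mel_len2 = int(token_len2 / 50 * 22050 / 256).

token_len2 = 100                          # 100 tokens at 50 Hz = 2.0 s of speech
mel_len2 = int(token_len2 / 50 * 22050 / 256)
assert mel_len2 == 172                    # ~2.0 s * 86.13 mel frames/s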
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,239 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import onnxruntime
+ import torch
+ import torch.nn.functional as F
+ from matcha.models.components.flow_matching import BASECFM
+
+
+ class ConditionalCFM(BASECFM):
+     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+         super().__init__(
+             n_feats=in_channels,
+             cfm_params=cfm_params,
+             n_spks=n_spks,
+             spk_emb_dim=spk_emb_dim,
+         )
+         self.t_scheduler = cfm_params.t_scheduler
+         self.training_cfg_rate = cfm_params.training_cfg_rate
+         self.inference_cfg_rate = cfm_params.inference_cfg_rate
+         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
+         # Just change the architecture of the estimator here
+         self.estimator = estimator
+
+     @torch.inference_mode()
+     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
+         """Forward diffusion
+
+         Args:
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             n_timesteps (int): number of diffusion steps
+             temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+             spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond: Not used but kept for future purposes
+
+         Returns:
+             sample: generated mel-spectrogram
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+
+         z = torch.randn_like(mu) * temperature
+         # Handle None flow_cache
+         if flow_cache is not None:
+             cache_size = flow_cache.shape[2]
+             # fix prompt and overlap part mu and z
+             if cache_size != 0:
+                 z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+                 mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+         else:
+             cache_size = 0
+         z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+         mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+         flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+         if self.t_scheduler == 'cosine':
+             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
+
+     def solve_euler(self, x, t_span, mu, mask, spks, cond):
+         """
+         Fixed-step Euler solver for ODEs.
+         Args:
+             x (torch.Tensor): random noise
+             t_span (torch.Tensor): n_timesteps interpolated
+                 shape: (n_timesteps + 1,)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond: Not used but kept for future purposes
+         """
+         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+         t = t.unsqueeze(dim=0)
+
+         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+         # Or in future might add like a return_all_steps flag
+         sol = []
+
+         if self.inference_cfg_rate > 0:
+             # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+             x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+             mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+             mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+             t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+             spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+             cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+         else:
+             x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+         for step in range(1, len(t_span)):
+             # Classifier-Free Guidance inference introduced in VoiceBox
+             if self.inference_cfg_rate > 0:
+                 x_in[:] = x
+                 mask_in[:] = mask
+                 mu_in[0] = mu
+                 t_in[:] = t.unsqueeze(0)
+                 spks_in[0] = spks
+                 cond_in[0] = cond
+             else:
+                 x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+             dphi_dt = self.forward_estimator(
+                 x_in, mask_in,
+                 mu_in, t_in,
+                 spks_in,
+                 cond_in
+             )
+             if self.inference_cfg_rate > 0:
+                 dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+                 dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+             x = x + dt * dphi_dt
+             t = t + dt
+             sol.append(x)
+             if step < len(t_span) - 1:
+                 dt = t_span[step + 1] - t
+
+         return sol[-1].float()
+
+     def forward_estimator(self, x, mask, mu, t, spks, cond):
+         if isinstance(self.estimator, torch.nn.Module):
+             return self.estimator.forward(x, mask, mu, t, spks, cond)
+         elif isinstance(self.estimator, onnxruntime.InferenceSession):
+             ort_inputs = {
+                 'x': x.cpu().numpy(),
+                 'mask': mask.cpu().numpy(),
+                 'mu': mu.cpu().numpy(),
+                 't': t.cpu().numpy(),
+                 'spks': spks.cpu().numpy(),
+                 'cond': cond.cpu().numpy()
+             }
+             output = self.estimator.run(None, ort_inputs)[0]
+             return torch.tensor(output, dtype=x.dtype, device=x.device)
+         else:
+             self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+             self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+             self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+             self.estimator.set_input_shape('t', (2,))
+             self.estimator.set_input_shape('spks', (2, 80))
+             self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+             # run trt engine
+             self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                        mask.contiguous().data_ptr(),
+                                        mu.contiguous().data_ptr(),
+                                        t.contiguous().data_ptr(),
+                                        spks.contiguous().data_ptr(),
+                                        cond.contiguous().data_ptr(),
+                                        x.data_ptr()])
+             return x
+
+     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+         """Computes diffusion loss
+
+         Args:
+             x1 (torch.Tensor): Target
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): target mask
+                 shape: (batch_size, 1, mel_timesteps)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+
+         Returns:
+             loss: conditional flow matching loss
+             y: conditional flow
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+         b, _, t = mu.shape
+
+         # random timestep
+         t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+         if self.t_scheduler == 'cosine':
+             t = 1 - torch.cos(t * 0.5 * torch.pi)
+         # sample noise p(x_0)
+         z = torch.randn_like(x1)
+
+         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+         u = x1 - (1 - self.sigma_min) * z
+
+         # during training, we randomly drop condition to trade off mode coverage and sample fidelity
+         if self.training_cfg_rate > 0:
+             cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
+             mu = mu * cfg_mask.view(-1, 1, 1)
+             spks = spks * cfg_mask.view(-1, 1)
+             cond = cond * cfg_mask.view(-1, 1, 1)
+
+         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
+         return loss, y
+
+
+ class CausalConditionalCFM(ConditionalCFM):
+     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+         super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+         self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+     @torch.inference_mode()
+     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+         """Forward diffusion
+
+         Args:
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             n_timesteps (int): number of diffusion steps
+             temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+             spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond: Not used but kept for future purposes
+
+         Returns:
+             sample: generated mel-spectrogram
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+
+         z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
+         # The fp16 flag is set on this module by the runtime; default to False
+         # if it was never assigned (defensive assumption).
+         if getattr(self, 'fp16', False) is True:
+             z = z.half()
+         # fix prompt and overlap part mu and z
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+         if self.t_scheduler == 'cosine':
+             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
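Two small pieces of solve_euler() can be verified in isolation: the cosine t-scheduler warps uniform step boundaries toward t = 0 (finer steps early in the ODE), and classifier-free guidance combines the conditional and unconditional velocity estimates. A sketch using the defaults above (inference_cfg_rate = 0.7):

import torch

n_timesteps, w = 10, 0.7
t_span = torch.linspace(0, 1, n_timesteps + 1)
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)   # cosine schedule
assert t_span[0].item() == 0.0
assert abs(t_span[-1].item() - 1.0) < 1e-6
assert t_span[1].item() < 1.0 / n_timesteps       # first step finer than uniform

v_cond, v_uncond = torch.tensor(2.0), torch.tensor(0.5)
v = (1 + w) * v_cond - w * v_uncond               # CFG-combined velocity
# each Euler update is then x <- x + dt * v, with dt = t_span[k + 1] - t_span[k]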
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,69 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Tuple
+ import torch.nn as nn
+ import torch
+ from torch.nn import functional as F
+ from cosyvoice.utils.mask import make_pad_mask
+
+
+ class InterpolateRegulator(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         sampling_ratios: Tuple,
+         out_channels: int = None,
+         groups: int = 1,
+     ):
+         super().__init__()
+         self.sampling_ratios = sampling_ratios
+         out_channels = out_channels or channels
+         model = nn.ModuleList([])
+         if len(sampling_ratios) > 0:
+             for _ in sampling_ratios:
+                 module = nn.Conv1d(channels, channels, 3, 1, 1)
+                 norm = nn.GroupNorm(groups, channels)
+                 act = nn.Mish()
+                 model.extend([module, norm, act])
+         model.append(
+             nn.Conv1d(channels, out_channels, 1, 1)
+         )
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x, ylens=None):
+         # x in (B, T, D)
+         mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
+         x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
+         out = self.model(x).transpose(1, 2).contiguous()
+         olens = ylens
+         return out * mask, olens
+
+     def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
+         # In inference mode, interpolate the prompt tokens and the target tokens
+         # (head/mid/tail) separately, so we get a clean separation point in the mel.
+         # x in (B, T, D)
+         if x2.shape[1] > 40:
+             x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+             x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                    mode='linear')
+             x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+         else:
+             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+         if x1.shape[1] != 0:
+             x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+             x = torch.concat([x1, x2], dim=2)
+         else:
+             x = x2
+         out = self.model(x).transpose(1, 2).contiguous()
+         return out, mel_len1 + mel_len2
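A worked example of the head/mid/tail split in InterpolateRegulator.inference(): with input_frame_rate = 50 and a 22050 Hz / 256-hop mel, a 20-token edge always maps to int(20 / 50 * 22050 / 256) = 34 mel frames, so only the middle chunk is stretched to absorb rounding, which keeps the prompt/target boundary time-aligned.

mel_len2 = 172                       # e.g. from 100 tokens, as in flow.py
edge = int(20 / 50 * 22050 / 256)    # 34 frames each for head and tail
mid = mel_len2 - 2 * edge            # 104 frames for tokens 20..-20
assert (edge, mid) == (34, 104)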
cosyvoice/hifigan/discriminator.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.utils import weight_norm
+ from typing import List, Optional, Tuple
+ from einops import rearrange
+ from torchaudio.transforms import Spectrogram
+
+
+ class MultipleDiscriminator(nn.Module):
+     def __init__(
+         self, mpd: nn.Module, mrd: nn.Module
+     ):
+         super().__init__()
+         self.mpd = mpd
+         self.mrd = mrd
+
+     def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
+         y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+         this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
+         y_d_rs += this_y_d_rs
+         y_d_gs += this_y_d_gs
+         fmap_rs += this_fmap_rs
+         fmap_gs += this_fmap_gs
+         this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
+         y_d_rs += this_y_d_rs
+         y_d_gs += this_y_d_gs
+         fmap_rs += this_fmap_rs
+         fmap_gs += this_fmap_gs
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class MultiResolutionDiscriminator(nn.Module):
+     def __init__(
+         self,
+         fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+         num_embeddings: Optional[int] = None,
+     ):
+         """
+         Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+         Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+         Args:
+             fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+             num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                 Defaults to None.
+         """
+
+         super().__init__()
+         self.discriminators = nn.ModuleList(
+             [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+         )
+
+     def forward(
+         self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+
+         for d in self.discriminators:
+             y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+             y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorR(nn.Module):
+     def __init__(
+         self,
+         window_length: int,
+         num_embeddings: Optional[int] = None,
+         channels: int = 32,
+         hop_factor: float = 0.25,
+         bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+     ):
+         super().__init__()
+         self.window_length = window_length
+         self.hop_factor = hop_factor
+         self.spec_fn = Spectrogram(
+             n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+         )
+         n_fft = window_length // 2 + 1
+         bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+         self.bands = bands
+         convs = lambda: nn.ModuleList(
+             [
+                 weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                 weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                 weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                 weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                 weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+             ]
+         )
+         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+         if num_embeddings is not None:
+             self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+             torch.nn.init.zeros_(self.emb.weight)
+
+         self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+     def spectrogram(self, x):
+         # Remove DC offset
+         x = x - x.mean(dim=-1, keepdims=True)
+         # Peak normalize the volume of input audio
+         x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+         x = self.spec_fn(x)
+         x = torch.view_as_real(x)
+         x = rearrange(x, "b f t c -> b c t f")
+         # Split into bands
+         x_bands = [x[..., b[0]: b[1]] for b in self.bands]
+         return x_bands
+
+     def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+         x_bands = self.spectrogram(x)
+         fmap = []
+         x = []
+         for band, stack in zip(x_bands, self.band_convs):
+             for i, layer in enumerate(stack):
+                 band = layer(band)
+                 band = torch.nn.functional.leaky_relu(band, 0.1)
+                 if i > 0:
+                     fmap.append(band)
+             x.append(band)
+         x = torch.cat(x, dim=-1)
+         if cond_embedding_id is not None:
+             emb = self.emb(cond_embedding_id)
+             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+         else:
+             h = 0
+         x = self.conv_post(x)
+         fmap.append(x)
+         x += h
+
+         return x, fmap
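A worked example of DiscriminatorR's band split: for window_length = 2048 the complex spectrogram has 2048 // 2 + 1 = 1025 frequency bins, and the fractional band edges are scaled to bin indices before slicing, so each of the five convolutional stacks sees a disjoint frequency range.

n_fft = 2048 // 2 + 1                                    # 1025 bins
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bins = [(int(lo * n_fft), int(hi * n_fft)) for lo, hi in bands]
assert bins == [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]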