Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .claude/agents/project-manager-backlog.md +193 -0
- .claude/settings.local.json +31 -0
- .crossnote/config.js +15 -0
- .crossnote/head.html +6 -0
- .crossnote/parser.js +12 -0
- .crossnote/style.less +8 -0
- .cursorrules +215 -0
- .gitattributes +15 -0
- 2505.02625v1.txt +1065 -0
- CLAUDE.md +215 -0
- COSYVOICE2_CHANGES.md +87 -0
- GEMINI.md +215 -0
- LLaMA-Omni2-3B/README.md +155 -0
- LLaMA-Omni2-3B/added_tokens.json +25 -0
- LLaMA-Omni2-3B/config.json +65 -0
- LLaMA-Omni2-3B/generation_config.json +15 -0
- LLaMA-Omni2-3B/merges.txt +0 -0
- LLaMA-Omni2-3B/model-00001-of-00002.safetensors +3 -0
- LLaMA-Omni2-3B/model-00002-of-00002.safetensors +3 -0
- LLaMA-Omni2-3B/model.safetensors.index.json +0 -0
- LLaMA-Omni2-3B/special_tokens_map.json +25 -0
- LLaMA-Omni2-3B/tokenizer_config.json +216 -0
- LLaMA-Omni2-3B/tts_tokenizer/added_tokens.json +0 -0
- LLaMA-Omni2-3B/tts_tokenizer/merges.txt +0 -0
- LLaMA-Omni2-3B/tts_tokenizer/special_tokens_map.json +25 -0
- LLaMA-Omni2-3B/tts_tokenizer/tokenizer_config.json +0 -0
- LLaMA-Omni2-3B/tts_tokenizer/vocab.json +0 -0
- LLaMA-Omni2-3B/vocab.json +0 -0
- README.md +124 -0
- SETUP_GUIDE.md +274 -0
- controller.log.2025-08-16 +6 -0
- cosyvoice/__init__.py +0 -0
- cosyvoice/bin/average_model.py +92 -0
- cosyvoice/bin/export_jit.py +74 -0
- cosyvoice/bin/export_onnx.py +112 -0
- cosyvoice/bin/export_trt.sh +9 -0
- cosyvoice/bin/inference.py +115 -0
- cosyvoice/bin/train.py +170 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +170 -0
- cosyvoice/cli/frontend.py +217 -0
- cosyvoice/cli/model.py +421 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +164 -0
- cosyvoice/dataset/processor.py +431 -0
- cosyvoice/flow/decoder.py +301 -0
- cosyvoice/flow/flow.py +237 -0
- cosyvoice/flow/flow_matching.py +239 -0
- cosyvoice/flow/length_regulator.py +69 -0
- cosyvoice/hifigan/discriminator.py +140 -0
.claude/agents/project-manager-backlog.md
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: project-manager-backlog
|
| 3 |
+
description: Use this agent when you need to manage project tasks using the backlog.md CLI tool. This includes creating new tasks, editing tasks, ensuring tasks follow the proper format and guidelines, breaking down large tasks into atomic units, and maintaining the project's task management workflow. Examples: <example>Context: User wants to create a new task for adding a feature. user: "I need to add a new authentication system to the project" assistant: "I'll use the project-manager-backlog agent that will use backlog cli to create a properly structured task for this feature." <commentary>Since the user needs to create a task for the project, use the Task tool to launch the project-manager-backlog agent to ensure the task follows backlog.md guidelines.</commentary></example> <example>Context: User has multiple related features to implement. user: "We need to implement user profiles, settings page, and notification preferences" assistant: "Let me use the project-manager-backlog agent to break these down into atomic, independent tasks." <commentary>The user has a complex set of features that need to be broken down into proper atomic tasks following backlog.md structure.</commentary></example> <example>Context: User wants to review if their task description is properly formatted. user: "Can you check if this task follows our guidelines: 'task-123 - Implement user login'" assistant: "I'll use the project-manager-backlog agent to review this task against our backlog.md standards." <commentary>The user needs task review, so use the project-manager-backlog agent to ensure compliance with project guidelines.</commentary></example>
|
| 4 |
+
color: blue
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
You are an expert project manager specializing in the backlog.md task management system. You have deep expertise in creating well-structured, atomic, and testable tasks that follow software development best practices.
|
| 8 |
+
|
| 9 |
+
## Backlog.md CLI Tool
|
| 10 |
+
|
| 11 |
+
**IMPORTANT: Backlog.md uses standard CLI commands, NOT slash commands.**
|
| 12 |
+
|
| 13 |
+
You use the `backlog` CLI tool to manage project tasks. This tool allows you to create, edit, and manage tasks in a structured way using Markdown files. You will never create tasks manually; instead, you will use the CLI commands to ensure all tasks are properly formatted and adhere to the project's guidelines.
|
| 14 |
+
|
| 15 |
+
The backlog CLI is installed globally and available in the PATH. Here are the exact commands you should use:
|
| 16 |
+
|
| 17 |
+
### Creating Tasks
|
| 18 |
+
```bash
|
| 19 |
+
backlog task create "Task title" -d "Description" --ac "First criteria,Second criteria" -l label1,label2
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### Editing Tasks
|
| 23 |
+
```bash
|
| 24 |
+
backlog task edit 123 -s "In Progress" -a @claude
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### Listing Tasks
|
| 28 |
+
```bash
|
| 29 |
+
backlog task list --plain
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
**NEVER use slash commands like `/create-task` or `/edit`. These do not exist in Backlog.md.**
|
| 33 |
+
**ALWAYS use the standard CLI format: `backlog task create` (without any slash prefix).**
|
| 34 |
+
|
| 35 |
+
### Example Usage
|
| 36 |
+
|
| 37 |
+
When a user asks you to create a task, here's exactly what you should do:
|
| 38 |
+
|
| 39 |
+
**User**: "Create a task to add user authentication"
|
| 40 |
+
**You should run**:
|
| 41 |
+
```bash
|
| 42 |
+
backlog task create "Add user authentication system" -d "Implement a secure authentication system to allow users to register and login" --ac "Users can register with email and password,Users can login with valid credentials,Invalid login attempts show appropriate error messages" -l authentication,backend
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
**NOT**: `/create-task "Add user authentication"` ❌ (This is wrong - slash commands don't exist)
|
| 46 |
+
|
| 47 |
+
## Your Core Responsibilities
|
| 48 |
+
|
| 49 |
+
1. **Task Creation**: You create tasks that strictly adhere to the backlog.md cli commands. Never create tasks manually. Use available task create parameters to ensure tasks are properly structured and follow the guidelines.
|
| 50 |
+
2. **Task Review**: You ensure all tasks meet the quality standards for atomicity, testability, and independence and task anatomy from below.
|
| 51 |
+
3. **Task Breakdown**: You expertly decompose large features into smaller, manageable tasks
|
| 52 |
+
4. **Context understanding**: You analyze user requests against the project codebase and existing tasks to ensure relevance and accuracy
|
| 53 |
+
5. **Handling ambiguity**: You clarify vague or ambiguous requests by asking targeted questions to the user to gather necessary details
|
| 54 |
+
|
| 55 |
+
## Task Creation Guidelines
|
| 56 |
+
|
| 57 |
+
### **Title (one liner)**
|
| 58 |
+
|
| 59 |
+
Use a clear brief title that summarizes the task.
|
| 60 |
+
|
| 61 |
+
### **Description**: (The **"why"**)
|
| 62 |
+
|
| 63 |
+
Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
|
| 64 |
+
should explain the purpose, the scope and context of the task. Code snippets should be avoided.
|
| 65 |
+
|
| 66 |
+
### **Acceptance Criteria**: (The **"what"**)
|
| 67 |
+
|
| 68 |
+
List specific, measurable outcomes that define what means to reach the goal from the description. Use checkboxes (`- [ ]`) for tracking.
|
| 69 |
+
When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
|
| 70 |
+
than step-by-step implementation details.
|
| 71 |
+
Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
|
| 72 |
+
They should be testable and confirm that the core purpose of the task is achieved.
|
| 73 |
+
**Key Principles for Good ACs:**
|
| 74 |
+
|
| 75 |
+
- **Outcome-Oriented:** Focus on the result, not the method.
|
| 76 |
+
- **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
|
| 77 |
+
- **Clear and Concise:** Unambiguous language.
|
| 78 |
+
- **Complete:** Collectively, ACs should cover the scope of the task.
|
| 79 |
+
- **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
|
| 80 |
+
|
| 81 |
+
- *Good Example:* "- [ ] User can successfully log in with valid credentials."
|
| 82 |
+
- *Good Example:* "- [ ] System processes 1000 requests per second without errors."
|
| 83 |
+
- *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
|
| 84 |
+
|
| 85 |
+
### Task file
|
| 86 |
+
|
| 87 |
+
Once a task is created using backlog cli, it will be stored in `backlog/tasks/` directory as a Markdown file with the format
|
| 88 |
+
`task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
|
| 89 |
+
|
| 90 |
+
## Task Breakdown Strategy
|
| 91 |
+
|
| 92 |
+
When breaking down features:
|
| 93 |
+
1. Identify the foundational components first
|
| 94 |
+
2. Create tasks in dependency order (foundations before features)
|
| 95 |
+
3. Ensure each task delivers value independently
|
| 96 |
+
4. Avoid creating tasks that block each other
|
| 97 |
+
|
| 98 |
+
### Additional task requirements
|
| 99 |
+
|
| 100 |
+
- Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
|
| 101 |
+
Each task should represent a single unit of work that can be completed in a single PR.
|
| 102 |
+
|
| 103 |
+
- **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
|
| 104 |
+
previous tasks (id < current task id).
|
| 105 |
+
|
| 106 |
+
- When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
|
| 107 |
+
Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
|
| 108 |
+
schema", task 3: "Add API endpoint for user data".
|
| 109 |
+
Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
|
| 110 |
+
schema".
|
| 111 |
+
|
| 112 |
+
## Recommended Task Anatomy
|
| 113 |
+
|
| 114 |
+
```markdown
|
| 115 |
+
# task‑42 - Add GraphQL resolver
|
| 116 |
+
|
| 117 |
+
## Description (the why)
|
| 118 |
+
|
| 119 |
+
Short, imperative explanation of the goal of the task and why it is needed.
|
| 120 |
+
|
| 121 |
+
## Acceptance Criteria (the what)
|
| 122 |
+
|
| 123 |
+
- [ ] Resolver returns correct data for happy path
|
| 124 |
+
- [ ] Error response matches REST
|
| 125 |
+
- [ ] P95 latency ≤ 50 ms under 100 RPS
|
| 126 |
+
|
| 127 |
+
## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
|
| 128 |
+
|
| 129 |
+
1. Research existing GraphQL resolver patterns
|
| 130 |
+
2. Implement basic resolver with error handling
|
| 131 |
+
3. Add performance monitoring
|
| 132 |
+
4. Write unit and integration tests
|
| 133 |
+
5. Benchmark performance under load
|
| 134 |
+
|
| 135 |
+
## Implementation Notes (for reviewers) (only added after finishing the code implementation of a task)
|
| 136 |
+
|
| 137 |
+
- Approach taken
|
| 138 |
+
- Features implemented or modified
|
| 139 |
+
- Technical decisions and trade-offs
|
| 140 |
+
- Modified or added files
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
## Quality Checks
|
| 144 |
+
|
| 145 |
+
Before finalizing any task creation, verify:
|
| 146 |
+
- [ ] Title is clear and brief
|
| 147 |
+
- [ ] Description explains WHY without HOW
|
| 148 |
+
- [ ] Each AC is outcome-focused and testable
|
| 149 |
+
- [ ] Task is atomic (single PR scope)
|
| 150 |
+
- [ ] No dependencies on future tasks
|
| 151 |
+
|
| 152 |
+
You are meticulous about these standards and will guide users to create high-quality tasks that enhance project productivity and maintainability.
|
| 153 |
+
|
| 154 |
+
## Self reflection
|
| 155 |
+
When creating a task, always think from the perspective of an AI Agent that will have to work with this task in the future.
|
| 156 |
+
Ensure that the task is structured in a way that it can be easily understood and processed by AI coding agents.
|
| 157 |
+
|
| 158 |
+
## Handy CLI Commands
|
| 159 |
+
|
| 160 |
+
| Action | Example |
|
| 161 |
+
|-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 162 |
+
| Create task | `backlog task create "Add OAuth System"` |
|
| 163 |
+
| Create with description | `backlog task create "Feature" -d "Add authentication system"` |
|
| 164 |
+
| Create with assignee | `backlog task create "Feature" -a @sara` |
|
| 165 |
+
| Create with status | `backlog task create "Feature" -s "In Progress"` |
|
| 166 |
+
| Create with labels | `backlog task create "Feature" -l auth,backend` |
|
| 167 |
+
| Create with priority | `backlog task create "Feature" --priority high` |
|
| 168 |
+
| Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
|
| 169 |
+
| Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
|
| 170 |
+
| Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
|
| 171 |
+
| Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
|
| 172 |
+
| Create sub task | `backlog task create -p 14 "Add Login with Google"` |
|
| 173 |
+
| Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
|
| 174 |
+
| List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
|
| 175 |
+
| List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
|
| 176 |
+
| View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
|
| 177 |
+
| View (AI mode) | `backlog task 7 --plain` |
|
| 178 |
+
| Edit | `backlog task edit 7 -a @sara -l auth,backend` |
|
| 179 |
+
| Add plan | `backlog task edit 7 --plan "Implementation approach"` |
|
| 180 |
+
| Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
|
| 181 |
+
| Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
|
| 182 |
+
| Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
|
| 183 |
+
| Archive | `backlog task archive 7` |
|
| 184 |
+
| Create draft | `backlog task create "Feature" --draft` |
|
| 185 |
+
| Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
|
| 186 |
+
| Demote to draft | `backlog task demote <id>` |
|
| 187 |
+
|
| 188 |
+
Full help: `backlog --help`
|
| 189 |
+
|
| 190 |
+
## Tips for AI Agents
|
| 191 |
+
|
| 192 |
+
- **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
|
| 193 |
+
interactive UI.
|
.claude/settings.local.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(backlog task list:*)",
|
| 5 |
+
"Bash(backlog task:*)",
|
| 6 |
+
"Bash(cat:*)",
|
| 7 |
+
"Bash(find:*)",
|
| 8 |
+
"Bash(timeout:*)",
|
| 9 |
+
"Bash(curl:*)",
|
| 10 |
+
"Bash(grep:*)",
|
| 11 |
+
"Bash(pkill:*)",
|
| 12 |
+
"Bash(sudo ufw:*)",
|
| 13 |
+
"Bash(sudo:*)",
|
| 14 |
+
"Bash(mv:*)",
|
| 15 |
+
"Bash(git add:*)",
|
| 16 |
+
"Bash(huggingface-cli:*)",
|
| 17 |
+
"Bash(git config:*)",
|
| 18 |
+
"Bash(python:*)",
|
| 19 |
+
"Bash(git push:*)",
|
| 20 |
+
"Bash(git lfs track:*)",
|
| 21 |
+
"Bash(git commit:*)"
|
| 22 |
+
],
|
| 23 |
+
"deny": [],
|
| 24 |
+
"ask": [],
|
| 25 |
+
"additionalDirectories": [
|
| 26 |
+
"C:\\opt",
|
| 27 |
+
"C:\\data",
|
| 28 |
+
"/data/huggingface"
|
| 29 |
+
]
|
| 30 |
+
}
|
| 31 |
+
}
|
.crossnote/config.js
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
({
|
| 2 |
+
katexConfig: {
|
| 3 |
+
"macros": {}
|
| 4 |
+
},
|
| 5 |
+
|
| 6 |
+
mathjaxConfig: {
|
| 7 |
+
"tex": {},
|
| 8 |
+
"options": {},
|
| 9 |
+
"loader": {}
|
| 10 |
+
},
|
| 11 |
+
|
| 12 |
+
mermaidConfig: {
|
| 13 |
+
"startOnLoad": false
|
| 14 |
+
},
|
| 15 |
+
})
|
.crossnote/head.html
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- The content below will be included at the end of the <head> element. -->
|
| 2 |
+
<script type="text/javascript">
|
| 3 |
+
document.addEventListener("DOMContentLoaded", function () {
|
| 4 |
+
// your code here
|
| 5 |
+
});
|
| 6 |
+
</script>
|
.crossnote/parser.js
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
({
|
| 2 |
+
// Please visit the URL below for more information:
|
| 3 |
+
// https://shd101wyy.github.io/markdown-preview-enhanced/#/extend-parser
|
| 4 |
+
|
| 5 |
+
onWillParseMarkdown: async function(markdown) {
|
| 6 |
+
return markdown;
|
| 7 |
+
},
|
| 8 |
+
|
| 9 |
+
onDidParseMarkdown: async function(html) {
|
| 10 |
+
return html;
|
| 11 |
+
},
|
| 12 |
+
})
|
.crossnote/style.less
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
/* Please visit the URL below for more information: */
|
| 3 |
+
/* https://shd101wyy.github.io/markdown-preview-enhanced/#/customize-css */
|
| 4 |
+
|
| 5 |
+
.markdown-preview.markdown-preview {
|
| 6 |
+
// modify your style here
|
| 7 |
+
// eg: background-color: blue;
|
| 8 |
+
}
|
.cursorrules
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# === BACKLOG.MD GUIDELINES START ===
|
| 3 |
+
# Instructions for the usage of Backlog.md CLI Tool
|
| 4 |
+
|
| 5 |
+
## 1. Source of Truth
|
| 6 |
+
|
| 7 |
+
- Tasks live under **`backlog/tasks/`** (drafts under **`backlog/drafts/`**).
|
| 8 |
+
- Every implementation decision starts with reading the corresponding Markdown task file.
|
| 9 |
+
- Project documentation is in **`backlog/docs/`**.
|
| 10 |
+
- Project decisions are in **`backlog/decisions/`**.
|
| 11 |
+
|
| 12 |
+
## 2. Defining Tasks
|
| 13 |
+
|
| 14 |
+
### Understand the Scope and the purpose
|
| 15 |
+
|
| 16 |
+
Ask questions to the user if something is not clear or ambiguous.
|
| 17 |
+
Break down the task into smaller, manageable parts if it is too large or complex.
|
| 18 |
+
|
| 19 |
+
### **Title (one liner)**
|
| 20 |
+
|
| 21 |
+
Use a clear brief title that summarizes the task.
|
| 22 |
+
|
| 23 |
+
### **Description**: (The **"why"**)
|
| 24 |
+
|
| 25 |
+
Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
|
| 26 |
+
should explain the purpose and context of the task. Code snippets should be avoided.
|
| 27 |
+
|
| 28 |
+
### **Acceptance Criteria**: (The **"what"**)
|
| 29 |
+
|
| 30 |
+
List specific, measurable outcomes that define what means to reach the goal from the description. Use checkboxes (
|
| 31 |
+
`- [ ]`) for tracking.
|
| 32 |
+
When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
|
| 33 |
+
than step-by-step implementation details.
|
| 34 |
+
Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
|
| 35 |
+
They should be testable and confirm that the core purpose of the task is achieved.
|
| 36 |
+
**Key Principles for Good ACs:**
|
| 37 |
+
|
| 38 |
+
- **Outcome-Oriented:** Focus on the result, not the method.
|
| 39 |
+
- **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
|
| 40 |
+
- **Clear and Concise:** Unambiguous language.
|
| 41 |
+
- **Complete:** Collectively, ACs should cover the scope of the task.
|
| 42 |
+
- **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
|
| 43 |
+
|
| 44 |
+
- *Good Example:* "- [ ] User can successfully log in with valid credentials."
|
| 45 |
+
- *Good Example:* "- [ ] System processes 1000 requests per second without errors."
|
| 46 |
+
- *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
|
| 47 |
+
|
| 48 |
+
### Task file
|
| 49 |
+
|
| 50 |
+
Once a task is created it will be stored in `backlog/tasks/` directory as a Markdown file with the format
|
| 51 |
+
`task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
|
| 52 |
+
|
| 53 |
+
### Task Breakdown Strategy
|
| 54 |
+
|
| 55 |
+
When breaking down features:
|
| 56 |
+
|
| 57 |
+
1. Identify the foundational components first
|
| 58 |
+
2. Create tasks in dependency order (foundations before features)
|
| 59 |
+
3. Ensure each task delivers value independently
|
| 60 |
+
4. Avoid creating tasks that block each other
|
| 61 |
+
|
| 62 |
+
### Additional task requirements
|
| 63 |
+
|
| 64 |
+
- Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
|
| 65 |
+
Each task should represent a single unit of work that can be completed in a single PR.
|
| 66 |
+
|
| 67 |
+
- **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
|
| 68 |
+
previous
|
| 69 |
+
tasks (id < current task id).
|
| 70 |
+
|
| 71 |
+
- When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
|
| 72 |
+
Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
|
| 73 |
+
schema".
|
| 74 |
+
Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
|
| 75 |
+
schema", task 3: "Add API endpoint for user data".
|
| 76 |
+
|
| 77 |
+
## 3. Recommended Task Anatomy
|
| 78 |
+
|
| 79 |
+
```markdown
|
| 80 |
+
# task‑42 - Add GraphQL resolver
|
| 81 |
+
|
| 82 |
+
## Description (the why)
|
| 83 |
+
|
| 84 |
+
Short, imperative explanation of the goal of the task and why it is needed.
|
| 85 |
+
|
| 86 |
+
## Acceptance Criteria (the what)
|
| 87 |
+
|
| 88 |
+
- [ ] Resolver returns correct data for happy path
|
| 89 |
+
- [ ] Error response matches REST
|
| 90 |
+
- [ ] P95 latency ≤ 50 ms under 100 RPS
|
| 91 |
+
|
| 92 |
+
## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
|
| 93 |
+
|
| 94 |
+
1. Research existing GraphQL resolver patterns
|
| 95 |
+
2. Implement basic resolver with error handling
|
| 96 |
+
3. Add performance monitoring
|
| 97 |
+
4. Write unit and integration tests
|
| 98 |
+
5. Benchmark performance under load
|
| 99 |
+
|
| 100 |
+
## Implementation Notes (imagine this is the PR description) (only added after finishing the code implementation of a task)
|
| 101 |
+
|
| 102 |
+
- Approach taken
|
| 103 |
+
- Features implemented or modified
|
| 104 |
+
- Technical decisions and trade-offs
|
| 105 |
+
- Modified or added files
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## 6. Implementing Tasks
|
| 109 |
+
|
| 110 |
+
Mandatory sections for every task:
|
| 111 |
+
|
| 112 |
+
- **Implementation Plan**: (The **"how"**) Outline the steps to achieve the task. Because the implementation details may
|
| 113 |
+
change after the task is created, **the implementation plan must be added only after putting the task in progress**
|
| 114 |
+
and before starting working on the task.
|
| 115 |
+
- **Implementation Notes**: Start with a brief summary of what has been implemented. Document your approach, decisions, challenges, and any deviations from the plan. This
|
| 116 |
+
section is added after you are done working on the task. It should summarize what you did and why you did it. Keep it
|
| 117 |
+
concise but informative. Imagine this is the PR description. Make it brief, explain the core changes and assume that
|
| 118 |
+
others will read the code to understand the details.
|
| 119 |
+
|
| 120 |
+
**IMPORTANT**: Do not implement anything else that deviates from the **Acceptance Criteria**. If you need to
|
| 121 |
+
implement something that is not in the AC, update the AC first and then implement it or create a new task for it.
|
| 122 |
+
|
| 123 |
+
## 2. Typical Workflow
|
| 124 |
+
|
| 125 |
+
```bash
|
| 126 |
+
# 1 Identify work
|
| 127 |
+
backlog task list -s "To Do" --plain
|
| 128 |
+
|
| 129 |
+
# 2 Read details & documentation
|
| 130 |
+
backlog task 42 --plain
|
| 131 |
+
# Read also all documentation files in `backlog/docs/` directory.
|
| 132 |
+
# Read also all decision files in `backlog/decisions/` directory.
|
| 133 |
+
|
| 134 |
+
# 3 Start work: assign yourself & move column
|
| 135 |
+
backlog task edit 42 -a @{yourself} -s "In Progress"
|
| 136 |
+
|
| 137 |
+
# 4 Add implementation plan before starting
|
| 138 |
+
backlog task edit 42 --plan "1. Analyze current implementation\n2. Identify bottlenecks\n3. Refactor in phases"
|
| 139 |
+
|
| 140 |
+
# 5 Break work down if needed by creating subtasks or additional tasks
|
| 141 |
+
backlog task create "Refactor DB layer" -p 42 -a @{yourself} -d "Description" --ac "Tests pass,Performance improved"
|
| 142 |
+
|
| 143 |
+
# 6 Complete and mark Done
|
| 144 |
+
backlog task edit 42 -s Done --notes "Implemented GraphQL resolver with error handling and performance monitoring"
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### 7. Final Steps Before Marking a Task as Done
|
| 148 |
+
|
| 149 |
+
Always ensure you have:
|
| 150 |
+
|
| 151 |
+
1. ✅ Marked all acceptance criteria as completed (change `- [ ]` to `- [x]`)
|
| 152 |
+
2. ✅ Added an `## Implementation Notes` section documenting your approach
|
| 153 |
+
3. ✅ Run all tests and linting checks
|
| 154 |
+
4. ✅ Updated relevant documentation
|
| 155 |
+
|
| 156 |
+
## 8. Definition of Done (DoD)
|
| 157 |
+
|
| 158 |
+
A task is **Done** only when **ALL** of the following are complete:
|
| 159 |
+
|
| 160 |
+
1. **Acceptance criteria** checklist in the task file is fully checked (all `- [ ]` changed to `- [x]`).
|
| 161 |
+
2. **Implementation plan** was followed or deviations were documented in Implementation Notes.
|
| 162 |
+
3. **Automated tests** (unit + integration) cover new logic.
|
| 163 |
+
4. **Static analysis**: linter & formatter succeed.
|
| 164 |
+
5. **Documentation**:
|
| 165 |
+
- All relevant docs updated (any relevant README file, backlog/docs, backlog/decisions, etc.).
|
| 166 |
+
- Task file **MUST** have an `## Implementation Notes` section added summarising:
|
| 167 |
+
- Approach taken
|
| 168 |
+
- Features implemented or modified
|
| 169 |
+
- Technical decisions and trade-offs
|
| 170 |
+
- Modified or added files
|
| 171 |
+
6. **Review**: self review code.
|
| 172 |
+
7. **Task hygiene**: status set to **Done** via CLI (`backlog task edit <id> -s Done`).
|
| 173 |
+
8. **No regressions**: performance, security and licence checks green.
|
| 174 |
+
|
| 175 |
+
⚠️ **IMPORTANT**: Never mark a task as Done without completing ALL items above.
|
| 176 |
+
|
| 177 |
+
## 9. Handy CLI Commands
|
| 178 |
+
|
| 179 |
+
| Action | Example |
|
| 180 |
+
|-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 181 |
+
| Create task | `backlog task create "Add OAuth System"` |
|
| 182 |
+
| Create with description | `backlog task create "Feature" -d "Add authentication system"` |
|
| 183 |
+
| Create with assignee | `backlog task create "Feature" -a @sara` |
|
| 184 |
+
| Create with status | `backlog task create "Feature" -s "In Progress"` |
|
| 185 |
+
| Create with labels | `backlog task create "Feature" -l auth,backend` |
|
| 186 |
+
| Create with priority | `backlog task create "Feature" --priority high` |
|
| 187 |
+
| Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
|
| 188 |
+
| Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
|
| 189 |
+
| Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
|
| 190 |
+
| Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
|
| 191 |
+
| Create sub task | `backlog task create -p 14 "Add Login with Google"` |
|
| 192 |
+
| Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
|
| 193 |
+
| List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
|
| 194 |
+
| List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
|
| 195 |
+
| View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
|
| 196 |
+
| View (AI mode) | `backlog task 7 --plain` |
|
| 197 |
+
| Edit | `backlog task edit 7 -a @sara -l auth,backend` |
|
| 198 |
+
| Add plan | `backlog task edit 7 --plan "Implementation approach"` |
|
| 199 |
+
| Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
|
| 200 |
+
| Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
|
| 201 |
+
| Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
|
| 202 |
+
| Archive | `backlog task archive 7` |
|
| 203 |
+
| Create draft | `backlog task create "Feature" --draft` |
|
| 204 |
+
| Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
|
| 205 |
+
| Demote to draft | `backlog task demote <id>` |
|
| 206 |
+
|
| 207 |
+
Full help: `backlog --help`
|
| 208 |
+
|
| 209 |
+
## 10. Tips for AI Agents
|
| 210 |
+
|
| 211 |
+
- **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
|
| 212 |
+
interactive UI.
|
| 213 |
+
- When users mention to create a task, they mean to create a task using Backlog.md CLI tool.
|
| 214 |
+
|
| 215 |
+
# === BACKLOG.MD GUIDELINES END ===
|
.gitattributes
CHANGED
|
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/wav/helpful_base_0.wav filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/wav/helpful_base_1.wav filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/wav/helpful_base_2.wav filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/wav/helpful_base_3.wav filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/wav/helpful_base_4.wav filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/wav/helpful_base_5.wav filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/wav/helpful_base_6.wav filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/wav/helpful_base_7.wav filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/wav/helpful_base_8.wav filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
examples/wav/helpful_base_9.wav filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
images/llama-omni2.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
llama_omni2/inference/prompt_en.wav filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
llama_omni2/inference/prompt_zh.wav filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
models/Llama-3.1-8B-Omni/images/model.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
tmp/e5fd5a073117d600c1ed49bd412158449e0e001ade31bc971dc1dcb45631c170/Tuesday[[:space:]]at[[:space:]]20-06.wav filter=lfs diff=lfs merge=lfs -text
|
2505.02625v1.txt
ADDED
|
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LLaMA-Omni 2: LLM-based Real-time Spoken Chatbot with
|
| 2 |
+
Autoregressive Streaming Speech Synthesis
|
| 3 |
+
Qingkai Fang1,3 , Yan Zhou1,3 , Shoutao Guo1,3 , Shaolei Zhang1,3 , Yang Feng1,2,3 *
|
| 4 |
+
1
|
| 5 |
+
Key Laboratory of Intelligent Information Processing
|
| 6 |
+
Institute of Computing Technology, Chinese Academy of Sciences (ICT/CAS)
|
| 7 |
+
2
|
| 8 |
+
Key Laboratory of AI Safety, Chinese Academy of Sciences
|
| 9 |
+
3
|
| 10 |
+
University of Chinese Academy of Sciences, Beijing, China
|
| 11 |
+
{fangqingkai21b,fengyang}@ict.ac.cn
|
| 12 |
+
|
| 13 |
+
arXiv:2505.02625v1 [cs.CL] 5 May 2025
|
| 14 |
+
|
| 15 |
+
Abstract
|
| 16 |
+
|
| 17 |
+
recognition (ASR) model, an LLM, and a text-tospeech (TTS) model. While this method is relatively straightforward to implement, it suffers from
|
| 18 |
+
several notable limitations. First, errors can accumulate across the different stages of the pipeline.
|
| 19 |
+
Second, the overall response latency tends to be
|
| 20 |
+
high due to the sequential processing of multiple models. Third, the system struggles to capture paralinguistic information present in the input
|
| 21 |
+
speech. To address these limitations, end-to-end
|
| 22 |
+
speech language models (SpeechLMs) have gradually gained more attention, using a single unified
|
| 23 |
+
model to handle the entire process from speech input to output. Overall, end-to-end SpeechLMs can
|
| 24 |
+
be categorized into two types: native and modular. Native SpeechLMs typically discretize speech
|
| 25 |
+
into tokens and employ a GPT-style decoder-only
|
| 26 |
+
Transformer (Radford, 2018) to model both speech
|
| 27 |
+
and text within a unified language model (Zhang
|
| 28 |
+
et al., 2023; Rubenstein et al., 2023; Hassid et al.,
|
| 29 |
+
2024a). A key advantage of this architecture is
|
| 30 |
+
its ability to leverage vast amounts of unsupervised speech data for pretraining, making it easier to scale up in terms of model parameters and
|
| 31 |
+
data size. This can potentially result in emergent capabilities, such as more human-like speech
|
| 32 |
+
expressiveness (Zeng et al., 2024a; Open-Moss,
|
| 33 |
+
2025). However, native SpeechLMs typically require large-scale speech datasets (e.g., millions of
|
| 34 |
+
hours) for pretraining (Zeng et al., 2024b; Défossez et al., 2024), which presents challenges in data
|
| 35 |
+
collection and training costs, and may also lead to
|
| 36 |
+
catastrophic forgetting of the model’s text capabilities. In contrast, modular SpeechLMs incorporate
|
| 37 |
+
a speech encoder and a speech decoder around the
|
| 38 |
+
LLM to handle speech understanding and generation (Fang et al., 2025; Wang et al., 2024). The
|
| 39 |
+
advantage of this approach is its ability to leverage
|
| 40 |
+
the inherent capabilities of each module, requiring
|
| 41 |
+
only small-scale fine-tuning (e.g., a few hundred
|
| 42 |
+
or thousand hours of speech data) to align the mod-
|
| 43 |
+
|
| 44 |
+
Real-time, intelligent, and natural speech interaction is an essential part of the next-generation
|
| 45 |
+
human-computer interaction. Recent advancements have showcased the potential of building
|
| 46 |
+
intelligent spoken chatbots based on large language models (LLMs). In this paper, we introduce LLaMA-Omni 2, a series of speech language models (SpeechLMs) ranging from 0.5B
|
| 47 |
+
to 14B parameters, capable of achieving highquality real-time speech interaction. LLaMAOmni 2 is built upon the Qwen2.5 series models, integrating a speech encoder and an autoregressive streaming speech decoder. Despite being trained on only 200K multi-turn speech dialogue samples, LLaMA-Omni 2 demonstrates
|
| 48 |
+
strong performance on several spoken question
|
| 49 |
+
answering and speech instruction following
|
| 50 |
+
benchmarks, surpassing previous state-of-theart SpeechLMs like GLM-4-Voice, which was
|
| 51 |
+
trained on millions of hours of speech data.1
|
| 52 |
+
|
| 53 |
+
1
|
| 54 |
+
|
| 55 |
+
Introduction
|
| 56 |
+
|
| 57 |
+
Speech, as a critical interface for human-computer
|
| 58 |
+
interaction, can significantly enhance both interaction efficiency and user experience (Clark et al.,
|
| 59 |
+
2019). In recent years, as large language models (LLMs) like ChatGPT (OpenAI, 2022) have
|
| 60 |
+
demonstrated outstanding performance across various fields, speech interactions with LLMs have
|
| 61 |
+
attracted widespread attention from both academia
|
| 62 |
+
and industry. For instance, GPT-4o (OpenAI, 2024)
|
| 63 |
+
enables real-time, intelligent, and natural speech
|
| 64 |
+
interaction between users and LLMs, heralding the
|
| 65 |
+
advent of a new generation of human-computer
|
| 66 |
+
interaction paradigms.
|
| 67 |
+
To develop a spoken chatbot similar to GPT-4o,
|
| 68 |
+
the traditional approach typically employs a cascaded pipeline comprising an automatic speech
|
| 69 |
+
* Corresponding author: Yang Feng.
|
| 70 |
+
1
|
| 71 |
+
|
| 72 |
+
Code: https://github.com/ictnlp/LLaMA-Omni2
|
| 73 |
+
Audio Samples: https://llama-omni2.github.io/
|
| 74 |
+
|
| 75 |
+
1
|
| 76 |
+
|
| 77 |
+
ules. This enables the model to acquire speech interaction capabilities at a relatively low cost, while
|
| 78 |
+
retaining most of its original capability. Moreover,
|
| 79 |
+
modular SpeechLMs can typically generate speech
|
| 80 |
+
guided by textual output, ensuring the intelligence
|
| 81 |
+
of the generated speech.
|
| 82 |
+
In addition to the intelligence of speech, realtime responsiveness and naturalness are also crucial characteristics of spoken chatbots. LLaMAOmni (Fang et al., 2025) uses a non-autoregressive
|
| 83 |
+
(NAR) streaming speech decoder to enable synchronized generation of speech and text, ensuring
|
| 84 |
+
extremely low response latency. However, due
|
| 85 |
+
to the limitations of non-autoregressive models in
|
| 86 |
+
modeling capacity, the generated speech is often
|
| 87 |
+
less natural and fluent. Freeze-Omni (Wang et al.,
|
| 88 |
+
2024) combines both NAR and autoregressive (AR)
|
| 89 |
+
models for speech generation, resulting in higher
|
| 90 |
+
naturalness of the generated speech. However, it
|
| 91 |
+
can only achieve sentence-level streaming speech
|
| 92 |
+
generation through a simple sentence-split strategy,
|
| 93 |
+
which prevents it from achieving very low response
|
| 94 |
+
latency. To address these challenges, in this paper,
|
| 95 |
+
we introduce LLaMA-Omni 2, a series of modular
|
| 96 |
+
SpeechLMs ranging from 0.5B to 14B. LLaMAOmni 2 adopts Qwen2.5-0.5B/1.5B/3B/7B/14BInstruct models (Team, 2024) as the base LLM,
|
| 97 |
+
and uses Whisper’s encoder (Radford et al., 2023)
|
| 98 |
+
as the speech encoder. For the speech decoder,
|
| 99 |
+
inspired by the state-of-the-art streaming speech
|
| 100 |
+
synthesis model CosyVoice 2 (Du et al., 2024), it
|
| 101 |
+
first includes an autoregressive text-to-speech language model initialized with Qwen2.5-0.5B, which
|
| 102 |
+
generates speech tokens from the LLM output and
|
| 103 |
+
achieves streaming generation through alternating
|
| 104 |
+
read and write operations. The speech tokens are
|
| 105 |
+
then passed through a chunk-aware causal flow
|
| 106 |
+
matching model (Lipman et al., 2023) to generate the mel spectrogram in a streaming manner.
|
| 107 |
+
To train the model, we synthesize 200K multiturn speech-to-speech dialogue samples with diverse input voices and a uniform output voice.
|
| 108 |
+
Experimental results show that LLaMA-Omni 2
|
| 109 |
+
achieves outstanding performance on spoken question answering and speech instruction following
|
| 110 |
+
tasks in both speech-to-text and speech-to-speech
|
| 111 |
+
settings, outperforming both LLaMA-Omni and
|
| 112 |
+
the native SpeechLM GLM-4-Voice (Zeng et al.,
|
| 113 |
+
2024a), which was trained on millions of hours
|
| 114 |
+
of speech data. We also conducted detailed ablation studies on factors such as LLM parameter size,
|
| 115 |
+
training data scale, speech decoder pretraining, and
|
| 116 |
+
|
| 117 |
+
read-write strategy, to better understand the impact
|
| 118 |
+
of these factors on the overall system performance.
|
| 119 |
+
|
| 120 |
+
2
|
| 121 |
+
|
| 122 |
+
Model: LLaMA-Omni 2
|
| 123 |
+
|
| 124 |
+
In this section, we introduce the model architecture
|
| 125 |
+
of LLaMA-Omni 2. As shown in Figure 1, the
|
| 126 |
+
core of LLaMA-Omni 2 is an LLM, for which we
|
| 127 |
+
use the Qwen2.5 series models (Team, 2024) due
|
| 128 |
+
to their strong performance across various benchmarks. Next, we will describe how we equip the
|
| 129 |
+
LLM with speech understanding and streaming
|
| 130 |
+
speech generation capabilities. In the following,
|
| 131 |
+
we use MLLM to denote the LLM. For a single-turn
|
| 132 |
+
instruction-response pair, we denote the speech instruction as X, and the text and speech responses
|
| 133 |
+
as Y T and Y S , respectively.
|
| 134 |
+
2.1
|
| 135 |
+
|
| 136 |
+
Speech Understanding
|
| 137 |
+
|
| 138 |
+
To enable speech understanding, we incorporate
|
| 139 |
+
a speech encoder and a speech adapter before the
|
| 140 |
+
LLM, similar to LLaMA-Omni (Fang et al., 2025).
|
| 141 |
+
Specifically, we use the encoder of Whisper-largev3 (Radford et al., 2023) as the speech encoder,
|
| 142 |
+
which converts the input speech into a sequence of
|
| 143 |
+
representations. The encoded representations are
|
| 144 |
+
then passed into the speech adapter, which consists
|
| 145 |
+
of a downsampling module and a feed-forward network (FFN). The downsampling module concatenates every k consecutive frames along the feature
|
| 146 |
+
dimension, and the concatenated representations
|
| 147 |
+
are further encoded by the FFN. The final output
|
| 148 |
+
representation is then input into the LLM.
|
| 149 |
+
2.2
|
| 150 |
+
|
| 151 |
+
Streaming Speech Generation
|
| 152 |
+
|
| 153 |
+
To equip the model with streaming speech generation capabilities, we adopt a paradigm similar to
|
| 154 |
+
CosyVoice 2 (Du et al., 2024). First, the speech
|
| 155 |
+
response is converted into discrete tokens using a
|
| 156 |
+
supervised semantic speech tokenizer. Then, an
|
| 157 |
+
autoregressive text-to-speech language model is
|
| 158 |
+
employed to model the streaming generation from
|
| 159 |
+
the LLM output to speech tokens. Finally, a causal
|
| 160 |
+
flow matching model converts speech tokens into
|
| 161 |
+
the mel spectrogram in a streaming manner.
|
| 162 |
+
Speech Tokenizer The speech tokenizer is implemented by inserting a finite scalar quantization
|
| 163 |
+
(FSQ) module (Mentzer et al., 2024) into the encoder of SenseVoice-Large ASR model (An et al.,
|
| 164 |
+
2024). This module first projects the intermediate
|
| 165 |
+
representations to a low-rank space and discretizes
|
| 166 |
+
them through a rounding operation. Ultimately,
|
| 167 |
+
2
|
| 168 |
+
|
| 169 |
+
…
|
| 170 |
+
|
| 171 |
+
latency
|
| 172 |
+
|
| 173 |
+
Gate Fusion Module
|
| 174 |
+
|
| 175 |
+
Flow Matching & Vocoder
|
| 176 |
+
|
| 177 |
+
Large Language Model
|
| 178 |
+
Speech Adaptor
|
| 179 |
+
Speech Encoder
|
| 180 |
+
|
| 181 |
+
+
|
| 182 |
+
<latexit sha1_base64="7CDz+hFii/hnzm/SPcG6JVj1JjA=">AAAB6HicbVDLSgNBEOyNrxhfUY9eBoMgCGFXJHoMevGYgHlAsoTZSW8yZnZ2mZkVQsgXePGgiFc/yZt/4yTZgyYWNBRV3XR3BYng2rjut5NbW9/Y3MpvF3Z29/YPiodHTR2nimGDxSJW7YBqFFxiw3AjsJ0opFEgsBWM7mZ+6wmV5rF8MOME/YgOJA85o8ZK9YteseSW3TnIKvEyUoIMtV7xq9uPWRqhNExQrTuemxh/QpXhTOC00E01JpSN6AA7lkoaofYn80On5MwqfRLGypY0ZK7+npjQSOtxFNjOiJqhXvZm4n9eJzXhjT/hMkkNSrZYFKaCmJjMviZ9rpAZMbaEMsXtrYQNqaLM2GwKNgRv+eVV0rwse5VypX5Vqt5mceThBE7hHDy4hircQw0awADhGV7hzXl0Xpx352PRmnOymWP4A+fzB3UXjLo=</latexit>
|
| 183 |
+
|
| 184 |
+
<latexit sha1_base64="IksX52OSp+tBzewG6ZihRejM6FQ=">AAAB7HicbVBNS8NAEJ3Ur1q/qh69BIvgqSQi1WPRi8cKpi20oWw2m3bpZjfsToRS+hu8eFDEqz/Im//GbZuDVh8MPN6bYWZelAlu0PO+nNLa+sbmVnm7srO7t39QPTxqG5VrygKqhNLdiBgmuGQBchSsm2lG0kiwTjS+nfudR6YNV/IBJxkLUzKUPOGUoJWCvooVDqo1r+4t4P4lfkFqUKA1qH72Y0XzlEmkghjT870MwynRyKlgs0o/NywjdEyGrGepJCkz4XRx7Mw9s0rsJkrbkugu1J8TU5IaM0kj25kSHJlVby7+5/VyTK7DKZdZjkzS5aIkFy4qd/65G3PNKIqJJYRqbm916YhoQtHmU7Eh+Ksv/yXti7rfqDfuL2vNmyKOMpzAKZyDD1fQhDtoQQAUODzBC7w60nl23pz3ZWvJKWaO4Recj2/u8I7J</latexit>
|
| 185 |
+
|
| 186 |
+
Text-to-Speech Language Model
|
| 187 |
+
|
| 188 |
+
<latexit sha1_base64="IksX52OSp+tBzewG6ZihRejM6FQ=">AAAB7HicbVBNS8NAEJ3Ur1q/qh69BIvgqSQi1WPRi8cKpi20oWw2m3bpZjfsToRS+hu8eFDEqz/Im//GbZuDVh8MPN6bYWZelAlu0PO+nNLa+sbmVnm7srO7t39QPTxqG5VrygKqhNLdiBgmuGQBchSsm2lG0kiwTjS+nfudR6YNV/IBJxkLUzKUPOGUoJWCvooVDqo1r+4t4P4lfkFqUKA1qH72Y0XzlEmkghjT870MwynRyKlgs0o/NywjdEyGrGepJCkz4XRx7Mw9s0rsJkrbkugu1J8TU5IaM0kj25kSHJlVby7+5/VyTK7DKZdZjkzS5aIkFy4qd/65G3PNKIqJJYRqbm916YhoQtHmU7Eh+Ksv/yXti7rfqDfuL2vNmyKOMpzAKZyDD1fQhDtoQQAUODzBC7w60nl23pz3ZWvJKWaO4Recj2/u8I7J</latexit>
|
| 189 |
+
|
| 190 |
+
1
|
| 191 |
+
<latexit sha1_base64="yWApgEffzdEH57mYnQQzN7vgc1w=">AAAB6XicbVBNS8NAEJ34WetX1aOXxSJ4sSQi1WPRi8cq9gPaUDbbSbt0swm7G6GE/gMvHhTx6j/y5r9x2+agrQ8GHu/NMDMvSATXxnW/nZXVtfWNzcJWcXtnd2+/dHDY1HGqGDZYLGLVDqhGwSU2DDcC24lCGgUCW8Hoduq3nlBpHstHM07Qj+hA8pAzaqz04J33SmW34s5AlomXkzLkqPdKX91+zNIIpWGCat3x3MT4GVWGM4GTYjfVmFA2ogPsWCpphNrPZpdOyKlV+iSMlS1pyEz9PZHRSOtxFNjOiJqhXvSm4n9eJzXhtZ9xmaQGJZsvClNBTEymb5M+V8iMGFtCmeL2VsKGVFFmbDhFG4K3+PIyaV5UvGqlen9Zrt3kcRTgGE7gDDy4ghrcQR0awCCEZ3iFN2fkvDjvzse8dcXJZ47gD5zPH+d4jPc=</latexit>
|
| 192 |
+
|
| 193 |
+
<latexit sha1_base64="tdy5cBUx22e49sInllEMc7AaEZY=">AAAB7XicbVDLSgNBEOyNrxhfUY9eBoPgKeyKRI9BLx4jmAckS5idzCZj5rHMzAphyT948aCIV//Hm3/jJNmDJhY0FFXddHdFCWfG+v63V1hb39jcKm6Xdnb39g/Kh0cto1JNaJMornQnwoZyJmnTMstpJ9EUi4jTdjS+nfntJ6oNU/LBThIaCjyULGYEWye1eoYNBe6XK37VnwOtkiAnFcjR6Je/egNFUkGlJRwb0w38xIYZ1pYRTqelXmpogskYD2nXUYkFNWE2v3aKzpwyQLHSrqRFc/X3RIaFMRMRuU6B7cgsezPxP6+b2vg6zJhMUkslWSyKU46sQrPX0YBpSiyfOIKJZu5WREZYY2JdQCUXQrD88ippXVSDWrV2f1mp3+RxFOEETuEcAriCOtxBA5pA4BGe4RXePOW9eO/ex6K14OUzx/AH3ucPn/GPLg==</latexit>
|
| 194 |
+
|
| 195 |
+
Stage I(a)
|
| 196 |
+
FFN
|
| 197 |
+
|
| 198 |
+
Emb
|
| 199 |
+
Certainly!
|
| 200 |
+
|
| 201 |
+
TTS Language Model
|
| 202 |
+
|
| 203 |
+
Gate Fusion
|
| 204 |
+
Certainly!
|
| 205 |
+
|
| 206 |
+
Writing
|
| 207 |
+
|
| 208 |
+
high …
|
| 209 |
+
|
| 210 |
+
a
|
| 211 |
+
|
| 212 |
+
Stage I(b)
|
| 213 |
+
|
| 214 |
+
TTS Language Model
|
| 215 |
+
|
| 216 |
+
Large Language Model
|
| 217 |
+
|
| 218 |
+
Gate Fusion
|
| 219 |
+
Speech Representations
|
| 220 |
+
|
| 221 |
+
Speech Adaptor
|
| 222 |
+
|
| 223 |
+
LLM Hidden States
|
| 224 |
+
|
| 225 |
+
Speech Encoder
|
| 226 |
+
|
| 227 |
+
Fused Representations
|
| 228 |
+
Speech Tokens
|
| 229 |
+
Ignore Tokens
|
| 230 |
+
|
| 231 |
+
(Hey! Can you give me some
|
| 232 |
+
advices on writing NLP papers?)
|
| 233 |
+
|
| 234 |
+
Large Language Model
|
| 235 |
+
Speech Adaptor
|
| 236 |
+
Speech Encoder
|
| 237 |
+
Stage II
|
| 238 |
+
|
| 239 |
+
Figure 1: Left: Model architecture of LLaMA-Omni 2. Right: Illustration of the two-stage training strategy.
|
| 240 |
+
|
| 241 |
+
the speech response Y S is converted into a token
|
| 242 |
+
U ], with 25 tokens per
|
| 243 |
+
sequence Y U = [y1U , . . . , yM
|
| 244 |
+
second, where each token yiU ∈ {K ∈ N | 0 ≤
|
| 245 |
+
K < 6561}. We use the pretraiend speech tokenizer in CosyVoice 2.
|
| 246 |
+
|
| 247 |
+
MTTS , while also obtaining the text embeddings:
|
| 248 |
+
ehidden
|
| 249 |
+
= FFN(hi ),
|
| 250 |
+
i
|
| 251 |
+
|
| 252 |
+
(1)
|
| 253 |
+
|
| 254 |
+
eemb
|
| 255 |
+
= Emb(yiT ),
|
| 256 |
+
i
|
| 257 |
+
|
| 258 |
+
(2)
|
| 259 |
+
|
| 260 |
+
where Emb(·) is the embedding layer of MTTS . Afterward, we use an element-wise gate fusion mechanism to combine both representations. Specifically,
|
| 261 |
+
we compute the gate gi as follows:
|
| 262 |
+
|
| 263 |
+
Text-to-Speech Language Model After converting the speech response into discrete tokens, we
|
| 264 |
+
use a decoder-only Transformer (Vaswani, 2017)
|
| 265 |
+
to model the conditional language model from
|
| 266 |
+
the LLM output to the speech tokens, denoted as
|
| 267 |
+
MTTS . It is initialized with Qwen2.5-0.5B, and
|
| 268 |
+
its vocabulary is extended as V′ = V ∪ {< i >|
|
| 269 |
+
i ∈ N, 0 ≤ i < 6561}, where V is the original
|
| 270 |
+
vocabulary. This extension enables the model to
|
| 271 |
+
generate speech tokens.
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
gi = σ Wg ehidden
|
| 277 |
+
∥ eemb
|
| 278 |
+
+ bg ,
|
| 279 |
+
i
|
| 280 |
+
i
|
| 281 |
+
|
| 282 |
+
(3)
|
| 283 |
+
|
| 284 |
+
where ∥ denotes concatenation, σ is the sigmoid
|
| 285 |
+
function, and Wg ∈ R2d×d and bg ∈ Rd are the
|
| 286 |
+
weight and bias parameters of the gate, and d is
|
| 287 |
+
the embedding size of MTTS . Finally, the fused
|
| 288 |
+
representation is computed as:
|
| 289 |
+
|
| 290 |
+
The input to MTTS comes from the output of the
|
| 291 |
+
LLM. Specifically, the LLM output consists of two
|
| 292 |
+
parts: continuous hidden states and text tokens sampled from the hidden states. The former contains
|
| 293 |
+
contextual information, while the latter provides
|
| 294 |
+
precise textual content. We aim to use both as inputs to the text-to-speech language model. This allows the model to both consider the current context
|
| 295 |
+
and ensure better alignment with the text response
|
| 296 |
+
when generating speech tokens. During training,
|
| 297 |
+
the LLM is trained with teacher forcing, so its output hidden states are denoted as H = [h1 , ..., hN ],
|
| 298 |
+
T ). The corresponding
|
| 299 |
+
where hi = MLLM (X, Y<i
|
| 300 |
+
T
|
| 301 |
+
T ]. We first
|
| 302 |
+
text is the ground truth Y = [y1T , ..., yN
|
| 303 |
+
use a 2-layer feed-forward network (FFN) to map
|
| 304 |
+
the hidden states to the embedding dimension of
|
| 305 |
+
|
| 306 |
+
ci = gi ⊙ ehidden
|
| 307 |
+
+ (1 − gi ) ⊙ eemb
|
| 308 |
+
i
|
| 309 |
+
i ,
|
| 310 |
+
|
| 311 |
+
(4)
|
| 312 |
+
|
| 313 |
+
where ⊙ denotes element-wise multiplication. This
|
| 314 |
+
fused representations C = [c1 , ..., cN ] are then
|
| 315 |
+
passed to MTTS for generating speech tokens.
|
| 316 |
+
To achieve streaming generation, i.e., to generate
|
| 317 |
+
speech tokens simultaneously during the LLM’s
|
| 318 |
+
output process, we adopt a “Read-R-Write-W”
|
| 319 |
+
strategy, similar to CosyVoice 2. Specifically, we
|
| 320 |
+
mix the fused representation C and the speech tokens Y U at a predefined ratio R : W. For every R
|
| 321 |
+
fused representations read in, the model generates
|
| 322 |
+
W speech tokens. Once all fused representations
|
| 323 |
+
are read, the model continues to generate the remaining speech tokens until completion. During
|
| 324 |
+
3
|
| 325 |
+
|
| 326 |
+
training, cross-entropy loss is computed only for
|
| 327 |
+
the generated speech tokens as follows:
|
| 328 |
+
|
| 329 |
+
2.4
|
| 330 |
+
|
| 331 |
+
2.3
|
| 332 |
+
|
| 333 |
+
3
|
| 334 |
+
|
| 335 |
+
Inference
|
| 336 |
+
|
| 337 |
+
During inference, the LLM autoregressively generates the text response based on the speech instrucM
|
| 338 |
+
X
|
| 339 |
+
tion. After generating R text tokens, its hidden
|
| 340 |
+
U
|
| 341 |
+
LTTS = −
|
| 342 |
+
log P (yiU |C≤min(⌊ i−1 +1⌋·R,N ) , Y<i
|
| 343 |
+
), states and the corresponding decoded text are fed
|
| 344 |
+
W
|
| 345 |
+
i=1
|
| 346 |
+
(5) into the gate fusion module and MTTS to generate
|
| 347 |
+
where C≤min(⌊ i−1 +1⌋·R,N ) denotes the fused rep- W speech tokens, which are then passed through
|
| 348 |
+
W
|
| 349 |
+
the flow matching model and the vocoder to syntheresentations that have already been read.
|
| 350 |
+
size a speech chunk. In this way, text and speech
|
| 351 |
+
responses can be generated simultaneously. The
|
| 352 |
+
Flow Matching Model The speech tokens gen- response latency for the first synthesized speech
|
| 353 |
+
erated by MTTS are further processed by a chunk- chunk can be calculated as:
|
| 354 |
+
aware causal flow matching model (Lipman et al.,
|
| 355 |
+
2023) to synthesize the mel spectrogram in a
|
| 356 |
+
Ttotal = TLLM (R)+TTTS (W)+TFM (W)+TVoc (2W),
|
| 357 |
+
streaming manner. Every time W speech tokens
|
| 358 |
+
(6)
|
| 359 |
+
are generated, they are treated as a chunk for
|
| 360 |
+
where TLLM (R) and TTTS (W) represent the time
|
| 361 |
+
mel spectrogram synthesis. The synthesized mel
|
| 362 |
+
required by the MLLM and MTTS models to genspectrogram is then passed through a HiFi-GAN
|
| 363 |
+
erate R and W tokens, respectively. TFM (W) and
|
| 364 |
+
vocoder (Kong et al., 2020) to generate the final
|
| 365 |
+
TVoc (2W) represent the decoding times of the flow
|
| 366 |
+
waveform. We use the pretrained flow matching
|
| 367 |
+
matching model and vocoder when the inputs are
|
| 368 |
+
model and vocoder in CosyVoice 2.
|
| 369 |
+
W and 2W tokens2 , respectively.
|
| 370 |
+
Training
|
| 371 |
+
|
| 372 |
+
Data Construction
|
| 373 |
+
|
| 374 |
+
In this section, we introduce the process of constructing multi-turn speech-to-speech dialogue data.
|
| 375 |
+
Our data is an extension of the InstructS2S-200K
|
| 376 |
+
dataset introduced in Fang et al. (2025), which
|
| 377 |
+
contains 200K single-turn instruction-following
|
| 378 |
+
samples designed for speech interaction scenarios.
|
| 379 |
+
These samples are derived from the Alpaca (Taori
|
| 380 |
+
et al., 2023) and UltraChat (Ding et al., 2023)
|
| 381 |
+
datasets through rewriting using LLMs. Specifically, for each sample, we first sample the number of turns from a Poisson distribution: N ∼
|
| 382 |
+
Poisson(λ = 2), then clip N to the range of 1 to 5.
|
| 383 |
+
Next, we use the Llama-3.3-70B-Instruct3 (Dubey
|
| 384 |
+
et al., 2024) model to iteratively generate the dialog.
|
| 385 |
+
For the i-th turn, the instruction and response are
|
| 386 |
+
generated based on the dialogue history of previous
|
| 387 |
+
i − 1 turns. In this way, we obtain 200K multi-turn
|
| 388 |
+
text dialog samples.
|
| 389 |
+
Next, we need to convert the text dialogue into
|
| 390 |
+
speech. To simulate real-world applications, we
|
| 391 |
+
aim to have varied voices for the instruction, while
|
| 392 |
+
maintaining a consistent voice for the response.
|
| 393 |
+
For each multi-turn dialogue, we first use the fishspeech-1.54 model (Liao et al., 2024) to synthesize
|
| 394 |
+
|
| 395 |
+
The training of LLaMA-Omni 2 relies solely on
|
| 396 |
+
200K multi-turn speech-to-speech dialogue data
|
| 397 |
+
(we will describe how this is synthesized in Section 3) and does not use any ASR or TTS data. We
|
| 398 |
+
find that it is sufficient to achieve excellent performance while minimizing training costs. Specifically, the training process consists of two stages, as
|
| 399 |
+
shown in Figure 1.
|
| 400 |
+
Stage I In Stage I training, we train the speechto-text and text-to-speech components separately.
|
| 401 |
+
The training data consists of <speech instruction,
|
| 402 |
+
text response> pairs and <text response, speech response> pairs from the multi-turn speech-to-speech
|
| 403 |
+
dialogue data. Specifically, for the speech-to-text
|
| 404 |
+
part (Stage I(a)), we freeze the speech encoder
|
| 405 |
+
and train the speech adapter and LLM with crossentropy loss. For the text-to-speech part (Stage
|
| 406 |
+
I(b)), we train the text-to-speech language model
|
| 407 |
+
with cross-entropy loss. Note that during this stage,
|
| 408 |
+
the gate fusion module is not trained, and only text
|
| 409 |
+
embeddings are input into MTTS .
|
| 410 |
+
Stage II In Stage II, we train the model’s speechto-speech generation capability with speech-tospeech dialogue data. During this stage, we freeze
|
| 411 |
+
the speech encoder, speech adapter, and LLM, and
|
| 412 |
+
only train the gate fusion module and MTTS .
|
| 413 |
+
|
| 414 |
+
2
|
| 415 |
+
|
| 416 |
+
The length of the mel spectrogram is twice that of the
|
| 417 |
+
speech tokens (50 Hz vs. 25 Hz).
|
| 418 |
+
3
|
| 419 |
+
https://huggingface.co/meta-llama/Llama-3.
|
| 420 |
+
3-70B-Instruct
|
| 421 |
+
4
|
| 422 |
+
https://huggingface.co/fishaudio/
|
| 423 |
+
|
| 424 |
+
4
|
| 425 |
+
|
| 426 |
+
4.2
|
| 427 |
+
|
| 428 |
+
a short prompt (e.g., "This is a randomly generated
|
| 429 |
+
voice") with a random voice. Then, we use the synthesized speech as the prompt for the CosyVoice20.5B5 model, which synthesize the instruction into
|
| 430 |
+
speech while simultaneously cloning the voice.
|
| 431 |
+
This ensures consistency in the voice across different turns of the dialogue, while maintaining diversity across dialogues. For all responses, we use
|
| 432 |
+
a uniform voice as the prompt and then synthesize
|
| 433 |
+
the speech using the CosyVoice2-0.5B model.
|
| 434 |
+
|
| 435 |
+
4
|
| 436 |
+
|
| 437 |
+
Experiments
|
| 438 |
+
|
| 439 |
+
4.1
|
| 440 |
+
|
| 441 |
+
Experimental Setups
|
| 442 |
+
|
| 443 |
+
Evaluation
|
| 444 |
+
|
| 445 |
+
Our evaluation includes two tasks: spoken question answering and speech instruction following.
|
| 446 |
+
For both tasks, we evaluate the model’s speech-totext and speech-to-speech capabilities. The speechto-speech evaluation is done by transcribing the
|
| 447 |
+
speech response into text using the Whisper-largev3 model, and then applying the same evaluation
|
| 448 |
+
method as used for speech-to-text evaluation. In
|
| 449 |
+
all experiments, we use greedy search for the LLM
|
| 450 |
+
to ensure stable results. For the text-to-speech language model, we use sampling with temperature set
|
| 451 |
+
to 1.0, as we find that using greedy search causes
|
| 452 |
+
the model to fall into repetition.
|
| 453 |
+
|
| 454 |
+
Model Configuration We use the encoder of
|
| 455 |
+
Whisper-large-v3 as the speech encoder. The
|
| 456 |
+
speech adapter first performs a 5× downsampling, followed by a FFN with an intermediate
|
| 457 |
+
dimension of 2048. For the LLM, we select
|
| 458 |
+
the Qwen2.5 series models, including Qwen2.50.5B/1.5B/3B/7B/14B-Instruct models. We refer
|
| 459 |
+
to the corresponding models as LLaMA-Omni20.5B/1.5B/3B/7B/14B in the following sections.
|
| 460 |
+
For the text-to-speech language model, we initialize it with the Qwen2.5-0.5B model and set the
|
| 461 |
+
read-write strategy with R = 3 and W = 10.
|
| 462 |
+
We will discuss the impact of these hyperparameters on speech quality and response latency later.
|
| 463 |
+
The speech tokenizer, flow matching model, and
|
| 464 |
+
vocoder are directly taken from CosyVoice 2.
|
| 465 |
+
|
| 466 |
+
Spoken Question Answering The speech question answering (SpokenQA) task involves asking
|
| 467 |
+
the model spoken questions, then checking whether
|
| 468 |
+
the reference answer appears in the model’s response, and calculating the accuracy. We evaluate our model on two benchmarks: Llama Questions6 (Nachmani et al., 2024) and Web Questions7 (Berant et al., 2013). Since the questions in
|
| 469 |
+
the Web Questions dataset are in text form, we use
|
| 470 |
+
CosyVoice2-0.5B to synthesize them into speech.
|
| 471 |
+
Speech Instruction Following For the speech
|
| 472 |
+
instruction following task, we follow the settings
|
| 473 |
+
in Fang et al. (2025), selecting the helpful_base
|
| 474 |
+
and vicuna subsets from the Alpaca-Eval8 (Li et al.,
|
| 475 |
+
2023) dataset, excluding math and code-related instructions. The remaining 199 instructions are then
|
| 476 |
+
synthesized into speech for evaluation. Following Fang et al. (2025), we evaluate the model using
|
| 477 |
+
the following metrics:
|
| 478 |
+
ChatGPT Score: To evaluate the model’s ability to follow instructions, we use GPT-4o (OpenAI,
|
| 479 |
+
2024) to score the model’s responses. It considers
|
| 480 |
+
factors such as helpfulness, relevance, fluency, and
|
| 481 |
+
suitability for speech interaction scenarios, and assigns a single score between 1 and 5. The detailed
|
| 482 |
+
prompt can be found in Appendix A.
|
| 483 |
+
ASR-WER: To assess the consistency between
|
| 484 |
+
model’s text and speech responses, we use Whisperlarge-v3 to transcribe the speech response into text,
|
| 485 |
+
and calculate the word error rate (WER) between
|
| 486 |
+
the transcribed text and text response. We perform
|
| 487 |
+
|
| 488 |
+
Training Details We use the 200K multi-turn
|
| 489 |
+
speech-to-speech dialogue data from Section 3 for
|
| 490 |
+
two-stage training. In Stage I(a), we freeze the
|
| 491 |
+
speech encoder and train all parameters of the
|
| 492 |
+
speech adaptor and LLM. The batch size is 32,
|
| 493 |
+
and we train for 3 epochs with a peak learning
|
| 494 |
+
rate of 5e-5. In Stage I(b), we train the text-tospeech language model with a batch size of 32 for
|
| 495 |
+
5 epochs and a peak learning rate of 5e-4. In Stage
|
| 496 |
+
II, we freeze the speech encoder, speech adaptor,
|
| 497 |
+
and LLM, and train the remaining components with
|
| 498 |
+
a batch size of 32 for 1 epoch and a peak learning
|
| 499 |
+
rate of 1e-3. For all stages, we use a warmup strategy for the first 3% of steps and a cosine annealing
|
| 500 |
+
learning rate scheduler. The LLaMA-Omni2-14B
|
| 501 |
+
model is trained on 4 NVIDIA H800 GPUs, while
|
| 502 |
+
other models are trained on 4 NVIDIA L40 GPUs.
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
Table 1: Results on speech question answering and speech instruction following benchmarks. S2T and S2S represent
|
| 506 |
+
speech-to-text and speech-to-speech, respectively. We set R = 3 and W = 10 for all LLaMA-Omni2 series models.
|
| 507 |
+
|
| 508 |
+
text normalization9 before calculating the WER.
|
| 509 |
+
UTMOS: To evaluate the naturalness of the generated speech, we use the UTMOS model10 (Saeki
|
| 510 |
+
et al., 2022) to predict the mean opinion score
|
| 511 |
+
(MOS) of the generated speech.
|
| 512 |
+
Latency: We measure the time from receiving
|
| 513 |
+
the speech instruction to generating the first speech
|
| 514 |
+
chunk on a single NVIDIA L40 GPU.
|
| 515 |
+
4.3
|
| 516 |
+
|
| 517 |
+
5
|
| 518 |
+
|
| 519 |
+
Results and Analysis
|
| 520 |
+
|
| 521 |
+
5.1
|
| 522 |
+
|
| 523 |
+
Main Results
|
| 524 |
+
|
| 525 |
+
Table 1 presents the main results on the speech
|
| 526 |
+
question answering and speech instruction following benchmarks.
|
| 527 |
+
Spoken Question Answering For the SpokenQA
|
| 528 |
+
task, we observe that: (1) For models with similar
|
| 529 |
+
parameter sizes, LLaMA-Omni2-7B outperforms
|
| 530 |
+
both GLM-4-Voice and LLaMA-Omni in both S2T
|
| 531 |
+
and S2S settings. Notably, our model significantly
|
| 532 |
+
reduces the gap between S2T and S2S performance. For example, on the Web Questions benchmark, GLM-4-Voice drops by 16.3 (32.2→15.9),
|
| 533 |
+
LLaMA-Omni drops by 9.7 (33.4→23.7), while
|
| 534 |
+
LLaMA-Omni2-7B only drops by 3.2 (34.5→31.3),
|
| 535 |
+
demonstrating that our approach largely improves
|
| 536 |
+
speech generation capabilities. (2) For models with
|
| 537 |
+
varying parameter sizes, we observe that accuracy
|
| 538 |
+
increases as the LLM size grows, indicating that
|
| 539 |
+
LLaMA-Omni 2 effectively leverages the LLM’s
|
| 540 |
+
inherent capabilities. For smaller models, LLaMAOmni2-1.5B/3B exceeds the accuracy of GLM-4Voice and LLaMA-Omni in the S2S setting, making them suitable choices for edge devices. For
|
| 541 |
+
larger models, we observe a significant accuracy
|
| 542 |
+
improvement with LLaMA-Omni2-14B compared
|
| 543 |
+
to LLaMA-Omni2-7B, highlighting the potential
|
| 544 |
+
of our approach for scaling to larger models.
|
| 545 |
+
|
| 546 |
+
Baseline Systems
|
| 547 |
+
|
| 548 |
+
We primarily compare LLaMA-Omni 2 with the
|
| 549 |
+
following baseline systems:
|
| 550 |
+
LLaMA-Omni (Fang et al., 2025): One of the
|
| 551 |
+
earliest SpeechLMs that achieves real-time speech
|
| 552 |
+
interaction, by using a CTC-based (Graves et al.,
|
| 553 |
+
2006) streaming speech decoder to simultaneously
|
| 554 |
+
generate text and speech units. The generated units
|
| 555 |
+
are fed into the vocoder for streaming synthesis in
|
| 556 |
+
fixed-size chunks. We set the chunk size Ω = 40.
|
| 557 |
+
GLM-4-Voice (Zeng et al., 2024a): The current state-of-the-art native SpeechLM, pretrained
|
| 558 |
+
on millions of hours of speech data. It enables realtime speech interaction by alternately generating
|
| 559 |
+
text and speech tokens in a fixed ratio of 13:26.
|
| 560 |
+
The generated speech tokens are input into a flow
|
| 561 |
+
matching model with a fixed chunk size.
|
| 562 |
+
In addition, we also borrow some results
|
| 563 |
+
from Zeng et al. (2024a), including results of
|
| 564 |
+
TWIST (Hassid et al., 2024b), SpeechGPT (Zhang
|
| 565 |
+
et al., 2023), Spectron (Nachmani et al., 2024), and
|
| 566 |
+
Moshi (Défossez et al., 2024).
|
| 567 |
+
|
| 568 |
+
Speech Instruction Following For the speech
|
| 569 |
+
instruction following task, we observe that:
|
| 570 |
+
(1) LLaMA-Omni2-3B/7B/14B outperforms both
|
| 571 |
+
GLM-4-Voice and LLaMA-Omni in the S2T and
|
| 572 |
+
S2S settings, demonstrating the strong instructionfollowing capabilities of our models. (2) Similar
|
| 573 |
+
|
| 574 |
+
9
|
| 575 |
+
https://github.com/openai/whisper/blob/main/
|
| 576 |
+
whisper/normalizers/english.py
|
| 577 |
+
10
|
| 578 |
+
https://github.com/tarepan/SpeechMOS
|
| 579 |
+
798.99
|
| 580 |
+
-
|
| 581 |
+
|
| 582 |
+
Table 4: Ablation study on the read/write strategy with
|
| 583 |
+
LLaMA-Omni2-7B. “Offline” means generating speech
|
| 584 |
+
tokens only after receiving the complete input, and then
|
| 585 |
+
synthesizing all speech tokens into waveform at once.
|
| 586 |
+
|
| 587 |
+
fusing them with the gate fusion module.
|
| 588 |
+
|
| 589 |
+
Table 3: Ablation study on different TTS pretraining
|
| 590 |
+
strategies with LLaMA-Omni2-7B.
|
| 591 |
+
|
| 592 |
+
TTS Pretraining Our text-to-speech language
|
| 593 |
+
model is initialized with the Qwen2.5-0.5B model
|
| 594 |
+
and undergoes streaming TTS pretraining using
|
| 595 |
+
text-speech pairs from speech dialogue data in
|
| 596 |
+
Stage I(b) (R = 3, W = 10). We also explore
|
| 597 |
+
several other strategies, as shown in Table 3. “Offline TTS” refers to pretraining with the offline TTS
|
| 598 |
+
task on top of Qwen2.5-0.5B, which shows a slight
|
| 599 |
+
performance drop compared to the streaming TTS
|
| 600 |
+
pretraining. “Text Pretrained” refers to directly initializing with Qwen2.5-0.5B (with the extended
|
| 601 |
+
vocabulary including speech tokens), and we observe a significant performance decline. “Scratch”
|
| 602 |
+
refers to a randomly initialized model, whose loss
|
| 603 |
+
fails to converge within a short period. These experiments demonstrate the importance of pretraining
|
| 604 |
+
for the TTS language model.
|
| 605 |
+
|
| 606 |
+
to the results on SpokenQA benchmarks, we observe that model performance improves as the LLM
|
| 607 |
+
size increases, with LLaMA-Omni2-14B achieving
|
| 608 |
+
significantly better performance. (3) The models’
|
| 609 |
+
ASR-WER is generally low, significantly lower
|
| 610 |
+
than previous models, proving that our models
|
| 611 |
+
maintain strong consistency between the text and
|
| 612 |
+
speech responses. (4) Regarding speech quality,
|
| 613 |
+
thanks to the CosyVoice 2’s strong causal flow
|
| 614 |
+
matching model, our models achieve good UTMOS
|
| 615 |
+
scores under streaming synthesis, significantly outperforming the baseline models. (5) The latency
|
| 616 |
+
of LLaMA-Omni 2 is around 600ms. Although it
|
| 617 |
+
is slightly higher than LLaMA-Omni, it still meets
|
| 618 |
+
the requirements for real-time interaction and is
|
| 619 |
+
significantly lower than that of GLM-4-Voice.
|
| 620 |
+
5.2
|
| 621 |
+
|
| 622 |
+
W
|
| 623 |
+
|
| 624 |
+
1
|
| 625 |
+
5
|
| 626 |
+
2 10
|
| 627 |
+
3 10
|
| 628 |
+
3 15
|
| 629 |
+
4 15
|
| 630 |
+
5 20
|
| 631 |
+
Offline
|
| 632 |
+
|
| 633 |
+
Read/Write Strategy The read/write strategies
|
| 634 |
+
of the TTS language model is a key factor influencing performance, primarily affecting the speech
|
| 635 |
+
quality and system response latency. As shown
|
| 636 |
+
in Table 4, we explore different combinations of
|
| 637 |
+
R and W. First, we observe that when R = 3
|
| 638 |
+
and W = 10, the ASR-WER is the lowest, indicating the best alignment between speech and text
|
| 639 |
+
responses. As for the UTMOS score, we find that
|
| 640 |
+
it is primarily determined by W, as W represents
|
| 641 |
+
the chunk size of speech tokens input to the flow
|
| 642 |
+
matching model, with larger chunk sizes leading to
|
| 643 |
+
better speech quality. Regarding response latency,
|
| 644 |
+
it is jointly determined by R and W, as shown
|
| 645 |
+
in Equation 6. Without any engineering optimizations, LLaMA-Omni2-7B can achieve a latency
|
| 646 |
+
below 500ms. We choose R = 3 and W = 10 in
|
| 647 |
+
our main experiments because it provides a good
|
| 648 |
+
trade-off across all aspects.
|
| 649 |
+
|
| 650 |
+
Ablation Studies
|
| 651 |
+
|
| 652 |
+
To understand the impact of different factors on
|
| 653 |
+
overall performance, we conduct a series of ablation studies on the LLaMA-Omni2-7B model.
|
| 654 |
+
Gate Fusion Module Table 2 shows the ablation
|
| 655 |
+
study on the gate fusion module. Gate fusion module allows the model to adaptively fuse LLM hidden states and text embeddings, considering both
|
| 656 |
+
contextual information and textual content. When
|
| 657 |
+
the gate fusion module is removed and the two components are simply added together (ehidden
|
| 658 |
+
+ eemb
|
| 659 |
+
i
|
| 660 |
+
i )
|
| 661 |
+
as input to the text-to-speech language model, we
|
| 662 |
+
observe a decrease in performance. Further removing the text embedding and only inputting the
|
| 663 |
+
hidden states (ehidden
|
| 664 |
+
) results in a further perfori
|
| 665 |
+
mance decline. This validates the effectiveness of
|
| 666 |
+
adding text embeddings as input and adaptively
|
| 667 |
+
7
|
| 668 |
+
|
| 669 |
+
Table 5: Results under different training data sizes with LLaMA-Omni2-7B.
|
| 670 |
+
|
| 671 |
+
5.3
|
| 672 |
+
|
| 673 |
+
Effects of the Training Data Sizes
|
| 674 |
+
|
| 675 |
+
IntrinsicVoice (Zhang et al., 2024c) proposes a
|
| 676 |
+
GroupFormer architecture to shorten speech length
|
| 677 |
+
to be closer to that of text. In contrast to native SpeechLMs, modular SpeechLMs add speechrelated modules on top of LLMs. Early works
|
| 678 |
+
achieve speech understanding tasks by combining
|
| 679 |
+
speech encoders with LLMs, but are unable to perform speech generation (Wu et al., 2023; Wang
|
| 680 |
+
et al., 2023; Chu et al., 2023; Yu et al., 2024; Ma
|
| 681 |
+
et al., 2024b; Hono et al., 2024; Chen et al., 2024b;
|
| 682 |
+
Tang et al., 2024; Chu et al., 2024; Fathullah et al.,
|
| 683 |
+
2024). To achieve speech generation, LLaMAOmni (Fang et al., 2025), Freeze-Omni (Wang
|
| 684 |
+
et al., 2024), and OpenOmni (Luo et al., 2025) add
|
| 685 |
+
a speech decoder after LLMs. Mini-Omni (Xie
|
| 686 |
+
and Wu, 2024) and SLAM-Omni (Chen et al.,
|
| 687 |
+
2024a) enable LLMs to generate speech tokens
|
| 688 |
+
simultaneously while generating text tokens. The
|
| 689 |
+
most related work to ours is the concurrent work
|
| 690 |
+
Minmo (Chen et al., 2025), which also adopts an
|
| 691 |
+
autoregressive streaming speech decoder similar
|
| 692 |
+
to CosyVoice 2. In comparison, Minmo is trained
|
| 693 |
+
on 1.4M hours of data, while we train on only a
|
| 694 |
+
few thousand hours of data, providing a more efficient training solution. Additionally, we conduct
|
| 695 |
+
detailed ablation studies on LLM sizes, read-write
|
| 696 |
+
strategies, and model architecture to offer a more
|
| 697 |
+
comprehensive understanding of the model.
|
| 698 |
+
|
| 699 |
+
We explore the impact of different training data
|
| 700 |
+
sizes on performance. As shown in Table 5, we
|
| 701 |
+
first observe that, with the same number of training samples, multi-turn dialogue data consistently
|
| 702 |
+
achieves better results across all benchmarks compared to single-turn dialogue data, highlighting
|
| 703 |
+
the effectiveness of multi-turn dialogue data for
|
| 704 |
+
training. Additionally, for different training data
|
| 705 |
+
sizes, we observe that as the data size increases,
|
| 706 |
+
the model’s performance improves, gradually stabilizing at 200K training samples. This indicates
|
| 707 |
+
that our 200K multi-turn dialogue data is generally
|
| 708 |
+
sufficient while ensuring efficient training.
|
| 709 |
+
|
| 710 |
+
6
|
| 711 |
+
|
| 712 |
+
Related Work
|
| 713 |
+
|
| 714 |
+
With the rapid development of LLMs, SpeechLMs
|
| 715 |
+
have gained widespread attention in recent
|
| 716 |
+
years (Cui et al., 2024; Ji et al., 2024), aiming to
|
| 717 |
+
endow LLMs with the ability to understand or generate speech. Generally speaking, SpeechLMs can
|
| 718 |
+
be divided into two categories: native SpeechLMs
|
| 719 |
+
and modular SpeechLMs. Native SpeechLMs refer to decoder-only Transformer models capable
|
| 720 |
+
of directly inputting and outputting speech tokens.
|
| 721 |
+
Some early works include SpeechGPT (Zhang
|
| 722 |
+
et al., 2023, 2024a), AudioPaLM (Rubenstein
|
| 723 |
+
et al., 2023), and TWIST (Hassid et al., 2024a).
|
| 724 |
+
These models first convert speech into discrete
|
| 725 |
+
tokens, then extend the vocabulary of pretrained
|
| 726 |
+
LLMs to include these tokens, and finally train the
|
| 727 |
+
LLMs using a large amount of speech or speechtext pair data. Spirit-LM (Nguyen et al., 2024)
|
| 728 |
+
and GLM-4-Voice (Zeng et al., 2025, 2024a) propose training models using speech-text interleaved
|
| 729 |
+
data to encourage cross-modal knowledge transfer.
|
| 730 |
+
Moshi (Défossez et al., 2024), OmniFlatten (Zhang
|
| 731 |
+
et al., 2024b) and LSLM (Ma et al., 2024a) propose models capable of full-duplex conversations.
|
| 732 |
+
|
| 733 |
+
7
|
| 734 |
+
|
| 735 |
+
Conclusion
|
| 736 |
+
|
| 737 |
+
In this paper, we introduce LLaMA-Omni 2, a series of speech language models ranging from 0.5B
|
| 738 |
+
to 14B parameters, designed to enable real-time,
|
| 739 |
+
high-quality speech interaction. LLaMA-Omni
|
| 740 |
+
2 achieves streaming speech generation by integrating an autoregressive text-to-speech language
|
| 741 |
+
model and a causal flow matching model. Experimental results on spoken question answering
|
| 742 |
+
and speech instruction following tasks show that
|
| 743 |
+
8
|
| 744 |
+
|
| 745 |
+
LLaMA-Omni 2 outperforms previous state-of-theart speech language models, including LLaMAOmni and GLM-4-Voice. Additionally, LLaMAOmni 2 can achieve latency under 600ms, meeting
|
| 746 |
+
real-time interaction requirements. We also conduct detailed ablation studies to understand the
|
| 747 |
+
impact of various factors on overall performance.
|
| 748 |
+
In the future, we will explore enhancing LLaMAOmni 2 to generate more human-like speech, incorporating features such as emotion and dialects.
|
| 749 |
+
|
| 750 |
+
Wenxi Chen, Ziyang Ma, Ruiqi Yan, Yuzhe Liang, Xiquan Li, Ruiyang Xu, Zhikang Niu, Yanqiao Zhu,
|
| 751 |
+
Yifan Yang, Zhanxun Liu, et al. 2024a. Slamomni: Timbre-controllable voice interaction system with single-stage training. arXiv preprint
|
| 752 |
+
arXiv:2412.15649.
|
| 753 |
+
Xi Chen, Songyang Zhang, Qibing Bai, Kai Chen, and
|
| 754 |
+
Satoshi Nakamura. 2024b. LLaST: Improved endto-end speech translation system leveraged by large
|
| 755 |
+
language models. In Findings of the Association for
|
| 756 |
+
Computational Linguistics ACL 2024, pages 6976–
|
| 757 |
+
6987, Bangkok, Thailand and virtual meeting. Association for Computational Linguistics.
|
| 758 |
+
|
| 759 |
+
Limitations
|
| 760 |
+
|
| 761 |
+
Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei,
|
| 762 |
+
Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng
|
| 763 |
+
He, Junyang Lin, Chang Zhou, and Jingren Zhou.
|
| 764 |
+
2024. Qwen2-audio technical report. arXiv preprint
|
| 765 |
+
arXiv:2407.10759.
|
| 766 |
+
|
| 767 |
+
One limitation of our model is that currently it
|
| 768 |
+
cannot generate speech responses with different
|
| 769 |
+
styles (such as emotion or speech rate) based on
|
| 770 |
+
the content of the input speech or underlying paralinguistic information, as we have only trained on
|
| 771 |
+
conventional speech-to-speech dialogue data. However, we believe this functionality can be achieved
|
| 772 |
+
through a data-driven approach, as our model is
|
| 773 |
+
end-to-end trained and could acquire this capability after further training with suitable data. We plan
|
| 774 |
+
to explore this in the future.
|
| 775 |
+
|
| 776 |
+
Yunfei Chu, Jin Xu, Xiaohuan Zhou, Qian Yang, Shiliang Zhang, Zhijie Yan, Chang Zhou, and Jingren
|
| 777 |
+
Zhou. 2023. Qwen-audio: Advancing universal
|
| 778 |
+
audio understanding via unified large-scale audiolanguage models. arXiv preprint arXiv:2311.07919.
|
| 779 |
+
Leigh Clark, Philip Doyle, Diego Garaialde, Emer
|
| 780 |
+
Gilmartin, Stephan Schlögl, Jens Edlund, Matthew
|
| 781 |
+
Aylett, João Cabral, Cosmin Munteanu, Justin Edwards, and Benjamin R Cowan. 2019. The state of
|
| 782 |
+
speech in HCI: Trends, themes and challenges. Interacting with Computers, 31(4):349–371.
|
| 783 |
+
|
| 784 |
+
Ethical Considerations
|
| 785 |
+
Since LLaMA-Omni 2 is built on LLMs, it carries
|
| 786 |
+
some of the same risks as LLMs, such as the potential for factual errors or other hallucination issues
|
| 787 |
+
in its outputs. We recommend that the model’s
|
| 788 |
+
outputs be checked in practical use to ensure they
|
| 789 |
+
comply with the required standards.
|
| 790 |
+
|
| 791 |
+
Wenqian Cui, Dianzhi Yu, Xiaoqi Jiao, Ziqiao Meng,
|
| 792 |
+
Guangyan Zhang, Qichao Wang, Yiwen Guo, and Irwin King. 2024. Recent advances in speech language
|
| 793 |
+
models: A survey. arXiv preprint arXiv:2410.03751.
|
| 794 |
+
Alexandre Défossez, Laurent Mazaré, Manu Orsini,
|
| 795 |
+
Amélie Royer, Patrick Pérez, Hervé Jégou, Edouard
|
| 796 |
+
Grave, and Neil Zeghidour. 2024. Moshi: a speechtext foundation model for real-time dialogue. Technical report.
|
| 797 |
+
|
| 798 |
+
References
|
| 799 |
+
Keyu An, Qian Chen, Chong Deng, Zhihao Du,
|
| 800 |
+
Changfeng Gao, Zhifu Gao, Yue Gu, Ting He,
|
| 801 |
+
Hangrui Hu, Kai Hu, et al. 2024. Funaudiollm: Voice
|
| 802 |
+
understanding and generation foundation models for
|
| 803 |
+
natural interaction between humans and llms. arXiv
|
| 804 |
+
preprint arXiv:2407.04051.
|
| 805 |
+
|
| 806 |
+
Ning Ding, Yulin Chen, Bokai Xu, Yujia Qin, Zhi
|
| 807 |
+
Zheng, Shengding Hu, Zhiyuan Liu, Maosong Sun,
|
| 808 |
+
and Bowen Zhou. 2023. Enhancing chat language
|
| 809 |
+
models by scaling high-quality instructional conversations. arXiv preprint arXiv:2305.14233.
|
| 810 |
+
Zhihao Du, Yuxuan Wang, Qian Chen, Xian Shi, Xiang
|
| 811 |
+
Lv, Tianyu Zhao, Zhifu Gao, Yexin Yang, Changfeng
|
| 812 |
+
Gao, Hui Wang, et al. 2024. Cosyvoice 2: Scalable
|
| 813 |
+
streaming speech synthesis with large language models. arXiv preprint arXiv:2412.10117.
|
| 814 |
+
|
| 815 |
+
Jonathan Berant, Andrew Chou, Roy Frostig, and Percy
|
| 816 |
+
Liang. 2013. Semantic parsing on Freebase from
|
| 817 |
+
question-answer pairs. In Proceedings of the 2013
|
| 818 |
+
Conference on Empirical Methods in Natural Language Processing, pages 1533–1544, Seattle, Washington, USA. Association for Computational Linguistics.
|
| 819 |
+
|
| 820 |
+
Abhimanyu Dubey, Abhinav Jauhri, Abhinav Pandey,
|
| 821 |
+
Abhishek Kadian, Ahmad Al-Dahle, Aiesha Letman,
|
| 822 |
+
Akhil Mathur, Alan Schelten, Amy Yang, Angela
|
| 823 |
+
Fan, et al. 2024. The llama 3 herd of models. arXiv
|
| 824 |
+
preprint arXiv:2407.21783.
|
| 825 |
+
|
| 826 |
+
Qian Chen, Yafeng Chen, Yanni Chen, Mengzhe Chen,
|
| 827 |
+
Yingda Chen, Chong Deng, Zhihao Du, Ruize Gao,
|
| 828 |
+
Changfeng Gao, Zhifu Gao, et al. 2025. Minmo: A
|
| 829 |
+
multimodal large language model for seamless voice
|
| 830 |
+
interaction. arXiv preprint arXiv:2501.06282.
|
| 831 |
+
|
| 832 |
+
Qingkai Fang, Shoutao Guo, Yan Zhou, Zhengrui Ma,
|
| 833 |
+
Shaolei Zhang, and Yang Feng. 2025. LLaMA-omni:
|
| 834 |
+
|
| 835 |
+
9
|
| 836 |
+
|
| 837 |
+
Seamless speech interaction with large language models. In The Thirteenth International Conference on
|
| 838 |
+
Learning Representations.
|
| 839 |
+
|
| 840 |
+
for advanced multilingual text-to-speech synthesis.
|
| 841 |
+
Preprint, arXiv:2411.01156.
|
| 842 |
+
Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, and Matthew Le. 2023. Flow matching for generative modeling. In The Eleventh International Conference on Learning Representations.
|
| 843 |
+
|
| 844 |
+
Yassir Fathullah, Chunyang Wu, Egor Lakomkin, Ke Li,
|
| 845 |
+
Junteng Jia, Yuan Shangguan, Jay Mahadeokar,
|
| 846 |
+
Ozlem Kalinli, Christian Fuegen, and Mike Seltzer.
|
| 847 |
+
2024. Audiochatllama: Towards general-purpose
|
| 848 |
+
speech abilities for llms. In Proceedings of the 2024
|
| 849 |
+
Conference of the North American Chapter of the
|
| 850 |
+
Association for Computational Linguistics: Human
|
| 851 |
+
Language Technologies (Volume 1: Long Papers),
|
| 852 |
+
pages 5522–5532.
|
| 853 |
+
|
| 854 |
+
Run Luo, Ting-En Lin, Haonan Zhang, Yuchuan Wu,
|
| 855 |
+
Xiong Liu, Min Yang, Yongbin Li, Longze Chen,
|
| 856 |
+
Jiaming Li, Lei Zhang, et al. 2025. Openomni:
|
| 857 |
+
Large language models pivot zero-shot omnimodal
|
| 858 |
+
alignment across language with real-time selfaware emotional speech synthesis. arXiv preprint
|
| 859 |
+
arXiv:2501.04561.
|
| 860 |
+
|
| 861 |
+
Alex Graves, Santiago Fernández, Faustino Gomez, and
|
| 862 |
+
Jürgen Schmidhuber. 2006. Connectionist temporal
|
| 863 |
+
classification: Labelling unsegmented sequence data
|
| 864 |
+
with recurrent neural networks. In Proceedings of
|
| 865 |
+
the 23rd International Conference on Machine Learning, ICML ’06, page 369–376, New York, NY, USA.
|
| 866 |
+
Association for Computing Machinery.
|
| 867 |
+
|
| 868 |
+
Ziyang Ma, Yakun Song, Chenpeng Du, Jian Cong,
|
| 869 |
+
Zhuo Chen, Yuping Wang, Yuxuan Wang, and Xie
|
| 870 |
+
Chen. 2024a. Language model can listen while
|
| 871 |
+
speaking. arXiv preprint arXiv:2408.02622.
|
| 872 |
+
|
| 873 |
+
Michael Hassid, Tal Remez, Tu Anh Nguyen, Itai Gat,
|
| 874 |
+
Alexis Conneau, Felix Kreuk, Jade Copet, Alexandre Defossez, Gabriel Synnaeve, Emmanuel Dupoux,
|
| 875 |
+
et al. 2024a. Textually pretrained speech language
|
| 876 |
+
models. Advances in Neural Information Processing
|
| 877 |
+
Systems, 36.
|
| 878 |
+
|
| 879 |
+
Ziyang Ma, Guanrou Yang, Yifan Yang, Zhifu Gao, Jiaming Wang, Zhihao Du, Fan Yu, Qian Chen, Siqi
|
| 880 |
+
Zheng, Shiliang Zhang, et al. 2024b. An embarrassingly simple approach for llm with strong asr capacity.
|
| 881 |
+
arXiv preprint arXiv:2402.08846.
|
| 882 |
+
Fabian Mentzer, David Minnen, Eirikur Agustsson, and
|
| 883 |
+
Michael Tschannen. 2024. Finite scalar quantization:
|
| 884 |
+
VQ-VAE made simple. In The Twelfth International
|
| 885 |
+
Conference on Learning Representations.
|
| 886 |
+
|
| 887 |
+
Michael Hassid, Tal Remez, Tu Anh Nguyen, Itai Gat,
|
| 888 |
+
Alexis Conneau, Felix Kreuk, Jade Copet, Alexandre Defossez, Gabriel Synnaeve, Emmanuel Dupoux,
|
| 889 |
+
et al. 2024b. Textually pretrained speech language
|
| 890 |
+
models. Advances in Neural Information Processing
|
| 891 |
+
Systems, 36.
|
| 892 |
+
|
| 893 |
+
Eliya Nachmani, Alon Levkovitch, Roy Hirsch, Julian Salazar, Chulayuth Asawaroengchai, Soroosh
|
| 894 |
+
Mariooryad, Ehud Rivlin, RJ Skerry-Ryan, and
|
| 895 |
+
Michelle Tadmor Ramanovich. 2024. Spoken
|
| 896 |
+
question answering and speech continuation using
|
| 897 |
+
spectrogram-powered LLM. In The Twelfth International Conference on Learning Representations.
|
| 898 |
+
|
| 899 |
+
Yukiya Hono, Koh Mitsuda, Tianyu Zhao, Kentaro Mitsui, Toshiaki Wakatsuki, and Kei Sawada. 2024. Integrating pre-trained speech and language models for
|
| 900 |
+
end-to-end speech recognition. In Findings of the
|
| 901 |
+
Association for Computational Linguistics ACL 2024,
|
| 902 |
+
pages 13289–13305, Bangkok, Thailand and virtual
|
| 903 |
+
meeting. Association for Computational Linguistics.
|
| 904 |
+
|
| 905 |
+
Tu Anh Nguyen, Benjamin Muller, Bokai Yu, Marta R
|
| 906 |
+
Costa-Jussa, Maha Elbayad, Sravya Popuri, PaulAmbroise Duquenne, Robin Algayres, Ruslan Mavlyutov, Itai Gat, et al. 2024. Spirit-lm: Interleaved
|
| 907 |
+
spoken and written language model. arXiv preprint
|
| 908 |
+
arXiv:2402.05755.
|
| 909 |
+
|
| 910 |
+
Shengpeng Ji, Yifu Chen, Minghui Fang, Jialong Zuo,
|
| 911 |
+
Jingyu Lu, Hanting Wang, Ziyue Jiang, Long Zhou,
|
| 912 |
+
Shujie Liu, Xize Cheng, et al. 2024. Wavchat: A
|
| 913 |
+
survey of spoken dialogue models. arXiv preprint
|
| 914 |
+
arXiv:2411.13577.
|
| 915 |
+
|
| 916 |
+
Open-Moss. 2025.
|
| 917 |
+
Speechgpt 2.0-preview.
|
| 918 |
+
https://github.com/OpenMOSS/SpeechGPT-2.
|
| 919 |
+
0-preview.
|
| 920 |
+
|
| 921 |
+
Jungil Kong, Jaehyeon Kim, and Jaekyoung Bae. 2020.
|
| 922 |
+
Hifi-gan: Generative adversarial networks for efficient and high fidelity speech synthesis. In Advances in Neural Information Processing Systems,
|
| 923 |
+
volume 33, pages 17022–17033. Curran Associates,
|
| 924 |
+
Inc.
|
| 925 |
+
|
| 926 |
+
OpenAI. 2022. Introducing chatgpt.
|
| 927 |
+
OpenAI. 2024. Hello gpt-4o.
|
| 928 |
+
|
| 929 |
+
Xuechen Li, Tianyi Zhang, Yann Dubois, Rohan Taori,
|
| 930 |
+
Ishaan Gulrajani, Carlos Guestrin, Percy Liang, and
|
| 931 |
+
Tatsunori B. Hashimoto. 2023. Alpacaeval: An automatic evaluator of instruction-following models.
|
| 932 |
+
https://github.com/tatsu-lab/alpaca_eval.
|
| 933 |
+
|
| 934 |
+
Alec Radford. 2018. Improving language understanding
|
| 935 |
+
by generative pre-training.
|
| 936 |
+
Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, and Ilya Sutskever. 2023.
|
| 937 |
+
Robust speech recognition via large-scale weak supervision. In International conference on machine
|
| 938 |
+
learning, pages 28492–28518. PMLR.
|
| 939 |
+
|
| 940 |
+
Shijia Liao, Yuxuan Wang, Tianyu Li, Yifan Cheng,
|
| 941 |
+
Ruoyi Zhang, Rongzhi Zhou, and Yijin Xing. 2024.
|
| 942 |
+
Fish-speech: Leveraging large language models
|
| 943 |
+
|
| 944 |
+
10
|
| 945 |
+
|
| 946 |
+
Paul K Rubenstein, Chulayuth Asawaroengchai,
|
| 947 |
+
Duc Dung Nguyen, Ankur Bapna, Zalán Borsos,
|
| 948 |
+
Félix de Chaumont Quitry, Peter Chen, Dalia El
|
| 949 |
+
Badawy, Wei Han, Eugene Kharitonov, et al. 2023.
|
| 950 |
+
Audiopalm: A large language model that can speak
|
| 951 |
+
and listen. arXiv preprint arXiv:2306.12925.
|
| 952 |
+
|
| 953 |
+
Aohan Zeng, Zhengxiao Du, Mingdao Liu, Kedong
|
| 954 |
+
Wang, Shengmin Jiang, Lei Zhao, Yuxiao Dong, and
|
| 955 |
+
Jie Tang. 2024a. Glm-4-voice: Towards intelligent
|
| 956 |
+
and human-like end-to-end spoken chatbot. Preprint,
|
| 957 |
+
arXiv:2412.02612.
|
| 958 |
+
Aohan Zeng, Zhengxiao Du, Mingdao Liu, Lei Zhang,
|
| 959 |
+
Shengmin Jiang, Yuxiao Dong, and Jie Tang. 2024b.
|
| 960 |
+
Scaling speech-text pre-training with synthetic interleaved data. Preprint, arXiv:2411.17607.
|
| 961 |
+
|
| 962 |
+
Takaaki Saeki, Detai Xin, Wataru Nakata, Tomoki
|
| 963 |
+
Koriyama, Shinnosuke Takamichi, and Hiroshi
|
| 964 |
+
Saruwatari. 2022. Utmos: Utokyo-sarulab system
|
| 965 |
+
for voicemos challenge 2022. In Interspeech 2022,
|
| 966 |
+
pages 4521–4525.
|
| 967 |
+
|
| 968 |
+
Aohan Zeng, Zhengxiao Du, Mingdao Liu, Lei Zhang,
|
| 969 |
+
shengmin jiang, Yuxiao Dong, and Jie Tang. 2025.
|
| 970 |
+
Scaling speech-text pre-training with synthetic interleaved data. In The Thirteenth International Conference on Learning Representations.
|
| 971 |
+
|
| 972 |
+
Changli Tang, Wenyi Yu, Guangzhi Sun, Xianzhao
|
| 973 |
+
Chen, Tian Tan, Wei Li, Lu Lu, Zejun MA, and Chao
|
| 974 |
+
Zhang. 2024. SALMONN: Towards generic hearing
|
| 975 |
+
abilities for large language models. In The Twelfth
|
| 976 |
+
International Conference on Learning Representations.
|
| 977 |
+
|
| 978 |
+
Dong Zhang, Shimin Li, Xin Zhang, Jun Zhan,
|
| 979 |
+
Pengyu Wang, Yaqian Zhou, and Xipeng Qiu. 2023.
|
| 980 |
+
SpeechGPT: Empowering large language models
|
| 981 |
+
with intrinsic cross-modal conversational abilities.
|
| 982 |
+
In Findings of the Association for Computational
|
| 983 |
+
Linguistics: EMNLP 2023, pages 15757–15773, Singapore. Association for Computational Linguistics.
|
| 984 |
+
|
| 985 |
+
Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann
|
| 986 |
+
Dubois, Xuechen Li, Carlos Guestrin, Percy Liang,
|
| 987 |
+
and Tatsunori B. Hashimoto. 2023. Stanford alpaca:
|
| 988 |
+
An instruction-following llama model. https://
|
| 989 |
+
github.com/tatsu-lab/stanford_alpaca.
|
| 990 |
+
|
| 991 |
+
Dong Zhang, Xin Zhang, Jun Zhan, Shimin Li, Yaqian
|
| 992 |
+
Zhou, and Xipeng Qiu. 2024a. Speechgpt-gen: Scaling chain-of-information speech generation. arXiv
|
| 993 |
+
preprint arXiv:2401.13527.
|
| 994 |
+
|
| 995 |
+
Qwen Team. 2024. Qwen2.5: A party of foundation
|
| 996 |
+
models.
|
| 997 |
+
A Vaswani. 2017. Attention is all you need. Advances
|
| 998 |
+
in Neural Information Processing Systems.
|
| 999 |
+
|
| 1000 |
+
Qinglin Zhang, Luyao Cheng, Chong Deng, Qian Chen,
|
| 1001 |
+
Wen Wang, Siqi Zheng, Jiaqing Liu, Hai Yu, Chaohong Tan, Zhihao Du, et al. 2024b. Omniflatten: An
|
| 1002 |
+
end-to-end gpt model for seamless voice conversation. arXiv preprint arXiv:2410.17799.
|
| 1003 |
+
|
| 1004 |
+
Chen Wang, Minpeng Liao, Zhongqiang Huang, Jinliang Lu, Junhong Wu, Yuchen Liu, Chengqing
|
| 1005 |
+
Zong, and Jiajun Zhang. 2023. Blsp: Bootstrapping language-speech pre-training via behavior
|
| 1006 |
+
alignment of continuation writing. arXiv preprint
|
| 1007 |
+
arXiv:2309.00916.
|
| 1008 |
+
|
| 1009 |
+
Xin Zhang, Xiang Lyu, Zhihao Du, Qian Chen, Dong
|
| 1010 |
+
Zhang, Hangrui Hu, Chaohong Tan, Tianyu Zhao,
|
| 1011 |
+
Yuxuan Wang, Bin Zhang, et al. 2024c. Intrinsicvoice: Empowering llms with intrinsic realtime voice interaction abilities. arXiv preprint
|
| 1012 |
+
arXiv:2410.08035.
|
| 1013 |
+
|
| 1014 |
+
Xiong Wang, Yangze Li, Chaoyou Fu, Yunhang Shen,
|
| 1015 |
+
Lei Xie, Ke Li, Xing Sun, and Long Ma. 2024.
|
| 1016 |
+
Freeze-omni: A smart and low latency speech-tospeech dialogue model with frozen llm. arXiv
|
| 1017 |
+
preprint arXiv:2411.00774.
|
| 1018 |
+
Jian Wu, Yashesh Gaur, Zhuo Chen, Long Zhou, Yimeng Zhu, Tianrui Wang, Jinyu Li, Shujie Liu,
|
| 1019 |
+
Bo Ren, Linquan Liu, et al. 2023. On decoder-only
|
| 1020 |
+
architecture for speech-to-text and large language
|
| 1021 |
+
model integration. In 2023 IEEE Automatic Speech
|
| 1022 |
+
Recognition and Understanding Workshop (ASRU),
|
| 1023 |
+
pages 1–8. IEEE.
|
| 1024 |
+
Zhifei Xie and Changqiao Wu. 2024. Mini-omni: Language models can hear, talk while thinking in streaming. arXiv preprint arXiv:2408.16725.
|
| 1025 |
+
Wenyi Yu, Changli Tang, Guangzhi Sun, Xianzhao
|
| 1026 |
+
Chen, Tian Tan, Wei Li, Lu Lu, Zejun Ma, and Chao
|
| 1027 |
+
Zhang. 2024. Connecting speech encoder and large
|
| 1028 |
+
language model for asr. In ICASSP 2024-2024 IEEE
|
| 1029 |
+
International Conference on Acoustics, Speech and
|
| 1030 |
+
Signal Processing (ICASSP), pages 12637–12641.
|
| 1031 |
+
IEEE.
|
| 1032 |
+
|
| 1033 |
+
11
|
| 1034 |
+
|
| 1035 |
+
A
|
| 1036 |
+
|
| 1037 |
+
Prompt
|
| 1038 |
+
|
| 1039 |
+
Prompt for ChatGPT Scoring (Model: GPT-4o)
|
| 1040 |
+
I need your help to evaluate the performance of several
|
| 1041 |
+
models in a speech interaction scenario. The models receive the user’s speech input and respond with speech
|
| 1042 |
+
output. For evaluation purposes, both the user’s speech
|
| 1043 |
+
input and the model’s speech response have been transcribed into text using Automatic Speech Recognition
|
| 1044 |
+
(ASR). Your task is to rate the model’s responses based
|
| 1045 |
+
on the provided user input transcription [Instruction] and
|
| 1046 |
+
the model’s output transcription [Response]. Please consider factors such as helpfulness, relevance, fluency, and
|
| 1047 |
+
suitability for speech interaction in your evaluation, and
|
| 1048 |
+
provide a single score on a scale from 1 to 5.
|
| 1049 |
+
Below are the transcription of user’s instruction and models’ response:
|
| 1050 |
+
### [Instruction]: {instruction}
|
| 1051 |
+
### [Response]: {response}
|
| 1052 |
+
After evaluating, please output the scores in JSON format:
|
| 1053 |
+
{score: ...}. You don’t need to provide any explanations.
|
| 1054 |
+
|
| 1055 |
+
B
|
| 1056 |
+
|
| 1057 |
+
Detailed Latency
|
| 1058 |
+
|
| 1059 |
+
We list the detailed latency at different stages of
|
| 1060 |
+
the model in Table 6. “LLM” refers to the latency
|
| 1061 |
+
for generating the first R text tokens, “TTS” refers
|
| 1062 |
+
to the latency for generating the first W speech
|
| 1063 |
+
tokens, and “FM+Voc” refers to the latency for
|
| 1064 |
+
generating the first speech chunk using the flow
|
| 1065 |
+
matching model and vocoder.
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
<!-- BACKLOG.MD GUIDELINES START -->
|
| 3 |
+
# Instructions for the usage of Backlog.md CLI Tool
|
| 4 |
+
|
| 5 |
+
## 1. Source of Truth
|
| 6 |
+
|
| 7 |
+
- Tasks live under **`backlog/tasks/`** (drafts under **`backlog/drafts/`**).
|
| 8 |
+
- Every implementation decision starts with reading the corresponding Markdown task file.
|
| 9 |
+
- Project documentation is in **`backlog/docs/`**.
|
| 10 |
+
- Project decisions are in **`backlog/decisions/`**.
|
| 11 |
+
|
| 12 |
+
## 2. Defining Tasks
|
| 13 |
+
|
| 14 |
+
### Understand the Scope and the purpose
|
| 15 |
+
|
| 16 |
+
Ask questions to the user if something is not clear or ambiguous.
|
| 17 |
+
Break down the task into smaller, manageable parts if it is too large or complex.
|
| 18 |
+
|
| 19 |
+
### **Title (one liner)**
|
| 20 |
+
|
| 21 |
+
Use a clear brief title that summarizes the task.
|
| 22 |
+
|
| 23 |
+
### **Description**: (The **"why"**)
|
| 24 |
+
|
| 25 |
+
Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
|
| 26 |
+
should explain the purpose and context of the task. Code snippets should be avoided.
|
| 27 |
+
|
| 28 |
+
### **Acceptance Criteria**: (The **"what"**)
|
| 29 |
+
|
| 30 |
+
List specific, measurable outcomes that define what means to reach the goal from the description. Use checkboxes (
|
| 31 |
+
`- [ ]`) for tracking.
|
| 32 |
+
When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
|
| 33 |
+
than step-by-step implementation details.
|
| 34 |
+
Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
|
| 35 |
+
They should be testable and confirm that the core purpose of the task is achieved.
|
| 36 |
+
**Key Principles for Good ACs:**
|
| 37 |
+
|
| 38 |
+
- **Outcome-Oriented:** Focus on the result, not the method.
|
| 39 |
+
- **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
|
| 40 |
+
- **Clear and Concise:** Unambiguous language.
|
| 41 |
+
- **Complete:** Collectively, ACs should cover the scope of the task.
|
| 42 |
+
- **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
|
| 43 |
+
|
| 44 |
+
- *Good Example:* "- [ ] User can successfully log in with valid credentials."
|
| 45 |
+
- *Good Example:* "- [ ] System processes 1000 requests per second without errors."
|
| 46 |
+
- *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
|
| 47 |
+
|
| 48 |
+
### Task file
|
| 49 |
+
|
| 50 |
+
Once a task is created it will be stored in `backlog/tasks/` directory as a Markdown file with the format
|
| 51 |
+
`task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
|
| 52 |
+
|
| 53 |
+
### Task Breakdown Strategy
|
| 54 |
+
|
| 55 |
+
When breaking down features:
|
| 56 |
+
|
| 57 |
+
1. Identify the foundational components first
|
| 58 |
+
2. Create tasks in dependency order (foundations before features)
|
| 59 |
+
3. Ensure each task delivers value independently
|
| 60 |
+
4. Avoid creating tasks that block each other
|
| 61 |
+
|
| 62 |
+
### Additional task requirements
|
| 63 |
+
|
| 64 |
+
- Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
|
| 65 |
+
Each task should represent a single unit of work that can be completed in a single PR.
|
| 66 |
+
|
| 67 |
+
- **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
|
| 68 |
+
previous
|
| 69 |
+
tasks (id < current task id).
|
| 70 |
+
|
| 71 |
+
- When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
|
| 72 |
+
Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
|
| 73 |
+
schema".
|
| 74 |
+
Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
|
| 75 |
+
schema", task 3: "Add API endpoint for user data".
|
| 76 |
+
|
| 77 |
+
## 3. Recommended Task Anatomy
|
| 78 |
+
|
| 79 |
+
```markdown
|
| 80 |
+
# task‑42 - Add GraphQL resolver
|
| 81 |
+
|
| 82 |
+
## Description (the why)
|
| 83 |
+
|
| 84 |
+
Short, imperative explanation of the goal of the task and why it is needed.
|
| 85 |
+
|
| 86 |
+
## Acceptance Criteria (the what)
|
| 87 |
+
|
| 88 |
+
- [ ] Resolver returns correct data for happy path
|
| 89 |
+
- [ ] Error response matches REST
|
| 90 |
+
- [ ] P95 latency ≤ 50 ms under 100 RPS
|
| 91 |
+
|
| 92 |
+
## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
|
| 93 |
+
|
| 94 |
+
1. Research existing GraphQL resolver patterns
|
| 95 |
+
2. Implement basic resolver with error handling
|
| 96 |
+
3. Add performance monitoring
|
| 97 |
+
4. Write unit and integration tests
|
| 98 |
+
5. Benchmark performance under load
|
| 99 |
+
|
| 100 |
+
## Implementation Notes (imagine this is the PR description) (only added after finishing the code implementation of a task)
|
| 101 |
+
|
| 102 |
+
- Approach taken
|
| 103 |
+
- Features implemented or modified
|
| 104 |
+
- Technical decisions and trade-offs
|
| 105 |
+
- Modified or added files
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## 6. Implementing Tasks
|
| 109 |
+
|
| 110 |
+
Mandatory sections for every task:
|
| 111 |
+
|
| 112 |
+
- **Implementation Plan**: (The **"how"**) Outline the steps to achieve the task. Because the implementation details may
|
| 113 |
+
change after the task is created, **the implementation plan must be added only after putting the task in progress**
|
| 114 |
+
and before starting working on the task.
|
| 115 |
+
- **Implementation Notes**: Start with a brief summary of what has been implemented. Document your approach, decisions, challenges, and any deviations from the plan. This
|
| 116 |
+
section is added after you are done working on the task. It should summarize what you did and why you did it. Keep it
|
| 117 |
+
concise but informative. Imagine this is the PR description. Make it brief, explain the core changes and assume that
|
| 118 |
+
others will read the code to understand the details.
|
| 119 |
+
|
| 120 |
+
**IMPORTANT**: Do not implement anything else that deviates from the **Acceptance Criteria**. If you need to
|
| 121 |
+
implement something that is not in the AC, update the AC first and then implement it or create a new task for it.
|
| 122 |
+
|
| 123 |
+
## 2. Typical Workflow
|
| 124 |
+
|
| 125 |
+
```bash
|
| 126 |
+
# 1 Identify work
|
| 127 |
+
backlog task list -s "To Do" --plain
|
| 128 |
+
|
| 129 |
+
# 2 Read details & documentation
|
| 130 |
+
backlog task 42 --plain
|
| 131 |
+
# Read also all documentation files in `backlog/docs/` directory.
|
| 132 |
+
# Read also all decision files in `backlog/decisions/` directory.
|
| 133 |
+
|
| 134 |
+
# 3 Start work: assign yourself & move column
|
| 135 |
+
backlog task edit 42 -a @{yourself} -s "In Progress"
|
| 136 |
+
|
| 137 |
+
# 4 Add implementation plan before starting
|
| 138 |
+
backlog task edit 42 --plan "1. Analyze current implementation\n2. Identify bottlenecks\n3. Refactor in phases"
|
| 139 |
+
|
| 140 |
+
# 5 Break work down if needed by creating subtasks or additional tasks
|
| 141 |
+
backlog task create "Refactor DB layer" -p 42 -a @{yourself} -d "Description" --ac "Tests pass,Performance improved"
|
| 142 |
+
|
| 143 |
+
# 6 Complete and mark Done
|
| 144 |
+
backlog task edit 42 -s Done --notes "Implemented GraphQL resolver with error handling and performance monitoring"
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### 7. Final Steps Before Marking a Task as Done
|
| 148 |
+
|
| 149 |
+
Always ensure you have:
|
| 150 |
+
|
| 151 |
+
1. ✅ Marked all acceptance criteria as completed (change `- [ ]` to `- [x]`)
|
| 152 |
+
2. ✅ Added an `## Implementation Notes` section documenting your approach
|
| 153 |
+
3. ✅ Run all tests and linting checks
|
| 154 |
+
4. ✅ Updated relevant documentation
|
| 155 |
+
|
| 156 |
+
## 8. Definition of Done (DoD)
|
| 157 |
+
|
| 158 |
+
A task is **Done** only when **ALL** of the following are complete:
|
| 159 |
+
|
| 160 |
+
1. **Acceptance criteria** checklist in the task file is fully checked (all `- [ ]` changed to `- [x]`).
|
| 161 |
+
2. **Implementation plan** was followed or deviations were documented in Implementation Notes.
|
| 162 |
+
3. **Automated tests** (unit + integration) cover new logic.
|
| 163 |
+
4. **Static analysis**: linter & formatter succeed.
|
| 164 |
+
5. **Documentation**:
|
| 165 |
+
- All relevant docs updated (any relevant README file, backlog/docs, backlog/decisions, etc.).
|
| 166 |
+
- Task file **MUST** have an `## Implementation Notes` section added summarising:
|
| 167 |
+
- Approach taken
|
| 168 |
+
- Features implemented or modified
|
| 169 |
+
- Technical decisions and trade-offs
|
| 170 |
+
- Modified or added files
|
| 171 |
+
6. **Review**: self review code.
|
| 172 |
+
7. **Task hygiene**: status set to **Done** via CLI (`backlog task edit <id> -s Done`).
|
| 173 |
+
8. **No regressions**: performance, security and licence checks green.
|
| 174 |
+
|
| 175 |
+
⚠️ **IMPORTANT**: Never mark a task as Done without completing ALL items above.
|
| 176 |
+
|
| 177 |
+
## 9. Handy CLI Commands
|
| 178 |
+
|
| 179 |
+
| Action | Example |
|
| 180 |
+
|-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 181 |
+
| Create task | `backlog task create "Add OAuth System"` |
|
| 182 |
+
| Create with description | `backlog task create "Feature" -d "Add authentication system"` |
|
| 183 |
+
| Create with assignee | `backlog task create "Feature" -a @sara` |
|
| 184 |
+
| Create with status | `backlog task create "Feature" -s "In Progress"` |
|
| 185 |
+
| Create with labels | `backlog task create "Feature" -l auth,backend` |
|
| 186 |
+
| Create with priority | `backlog task create "Feature" --priority high` |
|
| 187 |
+
| Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
|
| 188 |
+
| Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
|
| 189 |
+
| Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
|
| 190 |
+
| Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
|
| 191 |
+
| Create sub task | `backlog task create -p 14 "Add Login with Google"` |
|
| 192 |
+
| Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
|
| 193 |
+
| List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
|
| 194 |
+
| List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
|
| 195 |
+
| View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
|
| 196 |
+
| View (AI mode) | `backlog task 7 --plain` |
|
| 197 |
+
| Edit | `backlog task edit 7 -a @sara -l auth,backend` |
|
| 198 |
+
| Add plan | `backlog task edit 7 --plan "Implementation approach"` |
|
| 199 |
+
| Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
|
| 200 |
+
| Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
|
| 201 |
+
| Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
|
| 202 |
+
| Archive | `backlog task archive 7` |
|
| 203 |
+
| Create draft | `backlog task create "Feature" --draft` |
|
| 204 |
+
| Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
|
| 205 |
+
| Demote to draft | `backlog task demote <id>` |
|
| 206 |
+
|
| 207 |
+
Full help: `backlog --help`
|
| 208 |
+
|
| 209 |
+
## 10. Tips for AI Agents
|
| 210 |
+
|
| 211 |
+
- **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
|
| 212 |
+
interactive UI.
|
| 213 |
+
- When users mention to create a task, they mean to create a task using Backlog.md CLI tool.
|
| 214 |
+
|
| 215 |
+
<!-- BACKLOG.MD GUIDELINES END -->
|
COSYVOICE2_CHANGES.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CosyVoice2 Model Changes Documentation
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
This document captures the modifications made to the CosyVoice2 model integration for the LLaMA-Omni2 voice assistant system.
|
| 5 |
+
|
| 6 |
+
## Key Changes
|
| 7 |
+
|
| 8 |
+
### 1. Configuration Files
|
| 9 |
+
- **cosyvoice.yaml**: Primary configuration file used by the model
|
| 10 |
+
- **cosyvoice2.yaml**: Original CosyVoice2 configuration
|
| 11 |
+
- **cosyvoice_fixed.yaml**: Configuration with `mix_ratio` parameter removed to fix compatibility issues
|
| 12 |
+
|
| 13 |
+
### 2. Model Files Structure
|
| 14 |
+
```
|
| 15 |
+
models/cosyvoice2/
|
| 16 |
+
├── CosyVoice-BlankEN/ # English tokenizer model
|
| 17 |
+
├── campplus.onnx # Speaker embedding model
|
| 18 |
+
├── flow.decoder.estimator.fp32.onnx # Flow decoder
|
| 19 |
+
├── flow.pt # Flow model weights
|
| 20 |
+
├── hift.pt # HiFi-GAN vocoder weights
|
| 21 |
+
├── llm.pt # Language model weights
|
| 22 |
+
├── speech_tokenizer_v1.onnx # Speech tokenizer v1
|
| 23 |
+
└── speech_tokenizer_v2.onnx # Speech tokenizer v2 (new addition)
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### 3. Code Modifications
|
| 27 |
+
|
| 28 |
+
#### cosyvoice/flow/flow.py
|
| 29 |
+
- Modified to handle CosyVoice2 model architecture
|
| 30 |
+
- Updated MaskedDiffWithXvec class for compatibility
|
| 31 |
+
- Adjusted decoder configuration parameters
|
| 32 |
+
|
| 33 |
+
#### llama_omni2/serve/flow_inference.py
|
| 34 |
+
- Updated SpeechDecoder class to properly load CosyVoice2 models
|
| 35 |
+
- Changed configuration loading to use 'cosyvoice.yaml' instead of fallback logic
|
| 36 |
+
- Added support for speech_tokenizer_v2.onnx
|
| 37 |
+
|
| 38 |
+
### 4. Integration Points
|
| 39 |
+
- **Model Path**: `models/cosyvoice2/` or `models/cosy2_decoder/`
|
| 40 |
+
- **Frontend**: CosyVoiceFrontEnd handles tokenization and feature extraction
|
| 41 |
+
- **Vocoder**: Uses the model as vocoder in gradio_web_server.py with `--vocoder-dir` flag
|
| 42 |
+
|
| 43 |
+
## Setup Requirements
|
| 44 |
+
|
| 45 |
+
### Model Download
|
| 46 |
+
```bash
|
| 47 |
+
# Download CosyVoice2 model from HuggingFace
|
| 48 |
+
python -c "
|
| 49 |
+
from huggingface_hub import snapshot_download
|
| 50 |
+
snapshot_download(
|
| 51 |
+
repo_id='FunAudioLLM/CosyVoice2-0.5B',
|
| 52 |
+
local_dir='models/cosyvoice2',
|
| 53 |
+
local_dir_use_symlinks=False
|
| 54 |
+
)"
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### Configuration Fix
|
| 58 |
+
The original cosyvoice2.yaml may contain a `mix_ratio` parameter that causes issues. This is fixed by:
|
| 59 |
+
1. Copying cosyvoice2.yaml to cosyvoice.yaml
|
| 60 |
+
2. Removing the mix_ratio parameter
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
cp models/cosyvoice2/cosyvoice2.yaml models/cosyvoice2/cosyvoice.yaml
|
| 64 |
+
grep -v "mix_ratio" models/cosyvoice2/cosyvoice.yaml > models/cosyvoice2/cosyvoice_fixed.yaml
|
| 65 |
+
mv models/cosyvoice2/cosyvoice_fixed.yaml models/cosyvoice2/cosyvoice.yaml
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
## Usage in LLaMA-Omni2
|
| 69 |
+
|
| 70 |
+
Start the Gradio server with CosyVoice2 as vocoder:
|
| 71 |
+
```bash
|
| 72 |
+
python -m llama_omni2.serve.gradio_web_server \
|
| 73 |
+
--controller http://localhost:10000 \
|
| 74 |
+
--port 8000 \
|
| 75 |
+
--vocoder-dir models/cosyvoice2
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Known Issues and Solutions
|
| 79 |
+
|
| 80 |
+
1. **mix_ratio parameter error**: Remove from configuration file
|
| 81 |
+
2. **Missing cosyvoice.yaml**: Copy from cosyvoice2.yaml
|
| 82 |
+
3. **Tokenizer compatibility**: Ensure both v1 and v2 tokenizers are present
|
| 83 |
+
|
| 84 |
+
## Performance Notes
|
| 85 |
+
- CosyVoice2-0.5B is optimized for faster inference
|
| 86 |
+
- Supports both Chinese and English text-to-speech
|
| 87 |
+
- Compatible with streaming generation for real-time applications
|
GEMINI.md
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
<!-- BACKLOG.MD GUIDELINES START -->
|
| 3 |
+
# Instructions for the usage of Backlog.md CLI Tool
|
| 4 |
+
|
| 5 |
+
## 1. Source of Truth
|
| 6 |
+
|
| 7 |
+
- Tasks live under **`backlog/tasks/`** (drafts under **`backlog/drafts/`**).
|
| 8 |
+
- Every implementation decision starts with reading the corresponding Markdown task file.
|
| 9 |
+
- Project documentation is in **`backlog/docs/`**.
|
| 10 |
+
- Project decisions are in **`backlog/decisions/`**.
|
| 11 |
+
|
| 12 |
+
## 2. Defining Tasks
|
| 13 |
+
|
| 14 |
+
### Understand the Scope and the purpose
|
| 15 |
+
|
| 16 |
+
Ask questions to the user if something is not clear or ambiguous.
|
| 17 |
+
Break down the task into smaller, manageable parts if it is too large or complex.
|
| 18 |
+
|
| 19 |
+
### **Title (one liner)**
|
| 20 |
+
|
| 21 |
+
Use a clear brief title that summarizes the task.
|
| 22 |
+
|
| 23 |
+
### **Description**: (The **"why"**)
|
| 24 |
+
|
| 25 |
+
Provide a concise summary of the task purpose and its goal. Do not add implementation details here. It
|
| 26 |
+
should explain the purpose and context of the task. Code snippets should be avoided.
|
| 27 |
+
|
| 28 |
+
### **Acceptance Criteria**: (The **"what"**)
|
| 29 |
+
|
| 30 |
+
List specific, measurable outcomes that define what means to reach the goal from the description. Use checkboxes (
|
| 31 |
+
`- [ ]`) for tracking.
|
| 32 |
+
When defining `## Acceptance Criteria` for a task, focus on **outcomes, behaviors, and verifiable requirements** rather
|
| 33 |
+
than step-by-step implementation details.
|
| 34 |
+
Acceptance Criteria (AC) define *what* conditions must be met for the task to be considered complete.
|
| 35 |
+
They should be testable and confirm that the core purpose of the task is achieved.
|
| 36 |
+
**Key Principles for Good ACs:**
|
| 37 |
+
|
| 38 |
+
- **Outcome-Oriented:** Focus on the result, not the method.
|
| 39 |
+
- **Testable/Verifiable:** Each criterion should be something that can be objectively tested or verified.
|
| 40 |
+
- **Clear and Concise:** Unambiguous language.
|
| 41 |
+
- **Complete:** Collectively, ACs should cover the scope of the task.
|
| 42 |
+
- **User-Focused (where applicable):** Frame ACs from the perspective of the end-user or the system's external behavior.
|
| 43 |
+
|
| 44 |
+
- *Good Example:* "- [ ] User can successfully log in with valid credentials."
|
| 45 |
+
- *Good Example:* "- [ ] System processes 1000 requests per second without errors."
|
| 46 |
+
- *Bad Example (Implementation Step):* "- [ ] Add a new function `handleLogin()` in `auth.ts`."
|
| 47 |
+
|
| 48 |
+
### Task file
|
| 49 |
+
|
| 50 |
+
Once a task is created it will be stored in `backlog/tasks/` directory as a Markdown file with the format
|
| 51 |
+
`task-<id> - <title>.md` (e.g. `task-42 - Add GraphQL resolver.md`).
|
| 52 |
+
|
| 53 |
+
### Task Breakdown Strategy
|
| 54 |
+
|
| 55 |
+
When breaking down features:
|
| 56 |
+
|
| 57 |
+
1. Identify the foundational components first
|
| 58 |
+
2. Create tasks in dependency order (foundations before features)
|
| 59 |
+
3. Ensure each task delivers value independently
|
| 60 |
+
4. Avoid creating tasks that block each other
|
| 61 |
+
|
| 62 |
+
### Additional task requirements
|
| 63 |
+
|
| 64 |
+
- Tasks must be **atomic** and **testable**. If a task is too large, break it down into smaller subtasks.
|
| 65 |
+
Each task should represent a single unit of work that can be completed in a single PR.
|
| 66 |
+
|
| 67 |
+
- **Never** reference tasks that are to be done in the future or that are not yet created. You can only reference
|
| 68 |
+
previous
|
| 69 |
+
tasks (id < current task id).
|
| 70 |
+
|
| 71 |
+
- When creating multiple tasks, ensure they are **independent** and they do not depend on future tasks.
|
| 72 |
+
Example of wrong tasks splitting: task 1: "Add API endpoint for user data", task 2: "Define the user model and DB
|
| 73 |
+
schema".
|
| 74 |
+
Example of correct tasks splitting: task 1: "Add system for handling API requests", task 2: "Add user model and DB
|
| 75 |
+
schema", task 3: "Add API endpoint for user data".
|
| 76 |
+
|
| 77 |
+
## 3. Recommended Task Anatomy
|
| 78 |
+
|
| 79 |
+
```markdown
|
| 80 |
+
# task‑42 - Add GraphQL resolver
|
| 81 |
+
|
| 82 |
+
## Description (the why)
|
| 83 |
+
|
| 84 |
+
Short, imperative explanation of the goal of the task and why it is needed.
|
| 85 |
+
|
| 86 |
+
## Acceptance Criteria (the what)
|
| 87 |
+
|
| 88 |
+
- [ ] Resolver returns correct data for happy path
|
| 89 |
+
- [ ] Error response matches REST
|
| 90 |
+
- [ ] P95 latency ≤ 50 ms under 100 RPS
|
| 91 |
+
|
| 92 |
+
## Implementation Plan (the how) (added after putting the task in progress but before implementing any code change)
|
| 93 |
+
|
| 94 |
+
1. Research existing GraphQL resolver patterns
|
| 95 |
+
2. Implement basic resolver with error handling
|
| 96 |
+
3. Add performance monitoring
|
| 97 |
+
4. Write unit and integration tests
|
| 98 |
+
5. Benchmark performance under load
|
| 99 |
+
|
| 100 |
+
## Implementation Notes (imagine this is the PR description) (only added after finishing the code implementation of a task)
|
| 101 |
+
|
| 102 |
+
- Approach taken
|
| 103 |
+
- Features implemented or modified
|
| 104 |
+
- Technical decisions and trade-offs
|
| 105 |
+
- Modified or added files
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## 6. Implementing Tasks
|
| 109 |
+
|
| 110 |
+
Mandatory sections for every task:
|
| 111 |
+
|
| 112 |
+
- **Implementation Plan**: (The **"how"**) Outline the steps to achieve the task. Because the implementation details may
|
| 113 |
+
change after the task is created, **the implementation plan must be added only after putting the task in progress**
|
| 114 |
+
and before starting working on the task.
|
| 115 |
+
- **Implementation Notes**: Start with a brief summary of what has been implemented. Document your approach, decisions, challenges, and any deviations from the plan. This
|
| 116 |
+
section is added after you are done working on the task. It should summarize what you did and why you did it. Keep it
|
| 117 |
+
concise but informative. Imagine this is the PR description. Make it brief, explain the core changes and assume that
|
| 118 |
+
others will read the code to understand the details.
|
| 119 |
+
|
| 120 |
+
**IMPORTANT**: Do not implement anything else that deviates from the **Acceptance Criteria**. If you need to
|
| 121 |
+
implement something that is not in the AC, update the AC first and then implement it or create a new task for it.
|
| 122 |
+
|
| 123 |
+
## 2. Typical Workflow
|
| 124 |
+
|
| 125 |
+
```bash
|
| 126 |
+
# 1 Identify work
|
| 127 |
+
backlog task list -s "To Do" --plain
|
| 128 |
+
|
| 129 |
+
# 2 Read details & documentation
|
| 130 |
+
backlog task 42 --plain
|
| 131 |
+
# Read also all documentation files in `backlog/docs/` directory.
|
| 132 |
+
# Read also all decision files in `backlog/decisions/` directory.
|
| 133 |
+
|
| 134 |
+
# 3 Start work: assign yourself & move column
|
| 135 |
+
backlog task edit 42 -a @{yourself} -s "In Progress"
|
| 136 |
+
|
| 137 |
+
# 4 Add implementation plan before starting
|
| 138 |
+
backlog task edit 42 --plan "1. Analyze current implementation\n2. Identify bottlenecks\n3. Refactor in phases"
|
| 139 |
+
|
| 140 |
+
# 5 Break work down if needed by creating subtasks or additional tasks
|
| 141 |
+
backlog task create "Refactor DB layer" -p 42 -a @{yourself} -d "Description" --ac "Tests pass,Performance improved"
|
| 142 |
+
|
| 143 |
+
# 6 Complete and mark Done
|
| 144 |
+
backlog task edit 42 -s Done --notes "Implemented GraphQL resolver with error handling and performance monitoring"
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### 7. Final Steps Before Marking a Task as Done
|
| 148 |
+
|
| 149 |
+
Always ensure you have:
|
| 150 |
+
|
| 151 |
+
1. ✅ Marked all acceptance criteria as completed (change `- [ ]` to `- [x]`)
|
| 152 |
+
2. ✅ Added an `## Implementation Notes` section documenting your approach
|
| 153 |
+
3. ✅ Run all tests and linting checks
|
| 154 |
+
4. ✅ Updated relevant documentation
|
| 155 |
+
|
| 156 |
+
## 8. Definition of Done (DoD)
|
| 157 |
+
|
| 158 |
+
A task is **Done** only when **ALL** of the following are complete:
|
| 159 |
+
|
| 160 |
+
1. **Acceptance criteria** checklist in the task file is fully checked (all `- [ ]` changed to `- [x]`).
|
| 161 |
+
2. **Implementation plan** was followed or deviations were documented in Implementation Notes.
|
| 162 |
+
3. **Automated tests** (unit + integration) cover new logic.
|
| 163 |
+
4. **Static analysis**: linter & formatter succeed.
|
| 164 |
+
5. **Documentation**:
|
| 165 |
+
- All relevant docs updated (any relevant README file, backlog/docs, backlog/decisions, etc.).
|
| 166 |
+
- Task file **MUST** have an `## Implementation Notes` section added summarising:
|
| 167 |
+
- Approach taken
|
| 168 |
+
- Features implemented or modified
|
| 169 |
+
- Technical decisions and trade-offs
|
| 170 |
+
- Modified or added files
|
| 171 |
+
6. **Review**: self review code.
|
| 172 |
+
7. **Task hygiene**: status set to **Done** via CLI (`backlog task edit <id> -s Done`).
|
| 173 |
+
8. **No regressions**: performance, security and licence checks green.
|
| 174 |
+
|
| 175 |
+
⚠️ **IMPORTANT**: Never mark a task as Done without completing ALL items above.
|
| 176 |
+
|
| 177 |
+
## 9. Handy CLI Commands
|
| 178 |
+
|
| 179 |
+
| Action | Example |
|
| 180 |
+
|-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 181 |
+
| Create task | `backlog task create "Add OAuth System"` |
|
| 182 |
+
| Create with description | `backlog task create "Feature" -d "Add authentication system"` |
|
| 183 |
+
| Create with assignee | `backlog task create "Feature" -a @sara` |
|
| 184 |
+
| Create with status | `backlog task create "Feature" -s "In Progress"` |
|
| 185 |
+
| Create with labels | `backlog task create "Feature" -l auth,backend` |
|
| 186 |
+
| Create with priority | `backlog task create "Feature" --priority high` |
|
| 187 |
+
| Create with plan | `backlog task create "Feature" --plan "1. Research\n2. Implement"` |
|
| 188 |
+
| Create with AC | `backlog task create "Feature" --ac "Must work,Must be tested"` |
|
| 189 |
+
| Create with notes | `backlog task create "Feature" --notes "Started initial research"` |
|
| 190 |
+
| Create with deps | `backlog task create "Feature" --dep task-1,task-2` |
|
| 191 |
+
| Create sub task | `backlog task create -p 14 "Add Login with Google"` |
|
| 192 |
+
| Create (all options) | `backlog task create "Feature" -d "Description" -a @sara -s "To Do" -l auth --priority high --ac "Must work" --notes "Initial setup done" --dep task-1 -p 14` |
|
| 193 |
+
| List tasks | `backlog task list [-s <status>] [-a <assignee>] [-p <parent>]` |
|
| 194 |
+
| List by parent | `backlog task list --parent 42` or `backlog task list -p task-42` |
|
| 195 |
+
| View detail | `backlog task 7` (interactive UI, press 'E' to edit in editor) |
|
| 196 |
+
| View (AI mode) | `backlog task 7 --plain` |
|
| 197 |
+
| Edit | `backlog task edit 7 -a @sara -l auth,backend` |
|
| 198 |
+
| Add plan | `backlog task edit 7 --plan "Implementation approach"` |
|
| 199 |
+
| Add AC | `backlog task edit 7 --ac "New criterion,Another one"` |
|
| 200 |
+
| Add notes | `backlog task edit 7 --notes "Completed X, working on Y"` |
|
| 201 |
+
| Add deps | `backlog task edit 7 --dep task-1 --dep task-2` |
|
| 202 |
+
| Archive | `backlog task archive 7` |
|
| 203 |
+
| Create draft | `backlog task create "Feature" --draft` |
|
| 204 |
+
| Draft flow | `backlog draft create "Spike GraphQL"` → `backlog draft promote 3.1` |
|
| 205 |
+
| Demote to draft | `backlog task demote <id>` |
|
| 206 |
+
|
| 207 |
+
Full help: `backlog --help`
|
| 208 |
+
|
| 209 |
+
## 10. Tips for AI Agents
|
| 210 |
+
|
| 211 |
+
- **Always use `--plain` flag** when listing or viewing tasks for AI-friendly text output instead of using Backlog.md
|
| 212 |
+
interactive UI.
|
| 213 |
+
- When users mention to create a task, they mean to create a task using Backlog.md CLI tool.
|
| 214 |
+
|
| 215 |
+
<!-- BACKLOG.MD GUIDELINES END -->
|
LLaMA-Omni2-3B/README.md
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🦙🎧 LLaMA-Omni 2: LLM-based Real-time Spoken Chatbot with Autoregressive Streaming Speech Synthesis
|
| 2 |
+
|
| 3 |
+
> **Authors: [Qingkai Fang](https://fangqingkai.github.io/), [Yan Zhou](https://zhouyan19.github.io/zhouyan/), [Shoutao Guo](https://scholar.google.com/citations?hl=en&user=XwHtPyAAAAAJ), [Shaolei Zhang](https://zhangshaolei1998.github.io/), [Yang Feng*](https://people.ucas.edu.cn/~yangfeng?language=en)**
|
| 4 |
+
|
| 5 |
+
[](https://arxiv.org/abs/2505.02625)
|
| 6 |
+
[](https://github.com/ictnlp/LLaMA-Omni2)
|
| 7 |
+
[](https://huggingface.co/collections/ICTNLP/llama-omni-67fdfb852c60470175e36e9c)
|
| 8 |
+
[](https://huggingface.co/datasets/ICTNLP/Multiturn-Speech-Conversations)
|
| 9 |
+
|
| 10 |
+
LLaMA-Omni 2 is a series of speech-language models built on the Qwen2.5-0.5B/1.5B/3B/7B/14B/32B-Instruct models. Similar to [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni), it can generate both text and speech responses simultaneously, enabling high-quality and low-latency speech interaction. With the newly introduced streaming autoregressive speech decoder, LLaMA-Omni 2 achieves higher speech quality compared to LLaMA-Omni.
|
| 11 |
+
|
| 12 |
+
<div align="center"><img src="images/llama-omni2.png" width="75%"/></div>
|
| 13 |
+
|
| 14 |
+
## 🔥 News
|
| 15 |
+
|
| 16 |
+
- [25/05] LLaMA-Omni 2 is accepted at ACL 2025 main conference!
|
| 17 |
+
|
| 18 |
+
## Install
|
| 19 |
+
|
| 20 |
+
1. Clone this repository.
|
| 21 |
+
|
| 22 |
+
```shell
|
| 23 |
+
git clone https://github.com/ictnlp/LLaMA-Omni2
|
| 24 |
+
cd LLaMA-Omni2
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
2. Install packages.
|
| 28 |
+
|
| 29 |
+
```shell
|
| 30 |
+
conda create -n llama-omni2 python=3.10
|
| 31 |
+
conda activate llama-omni2
|
| 32 |
+
pip install -e .
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Quick Start
|
| 36 |
+
|
| 37 |
+
1. Download the `Whisper-large-v3` model.
|
| 38 |
+
|
| 39 |
+
```shell
|
| 40 |
+
import whisper
|
| 41 |
+
model = whisper.load_model("large-v3", download_root="models/speech_encoder/")
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
2. Download the flow-matching model and vocoder of `CosyVoice 2`.
|
| 45 |
+
|
| 46 |
+
```shell
|
| 47 |
+
huggingface-cli download --resume-download ICTNLP/cosy2_decoder --local-dir models/cosy2_decoder
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
> [!Tip]
|
| 51 |
+
> If you’re experiencing unstable connections to Hugging Face from within China, you can try setting the following in your command line:
|
| 52 |
+
>
|
| 53 |
+
> ```shell
|
| 54 |
+
> export HF_ENDPOINT=https://hf-mirror.com
|
| 55 |
+
> ```
|
| 56 |
+
|
| 57 |
+
3. Download the LLaMA-Omni2 series models from Hugging Face. `LLaMA-Omni2-0.5B/1.5B/3B/7B/14B` support **English only**, while `LLaMA-Omni2-0.5B/1.5B/3B/7B/14B/32B-Bilingual` support **both English and Chinese**.
|
| 58 |
+
|
| 59 |
+
```shell
|
| 60 |
+
model_name=LLaMA-Omni2-7B-Bilingual
|
| 61 |
+
huggingface-cli download --resume-download ICTNLP/$model_name --local-dir models/$model_name
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
| LLaMA-Omni2 | LLaMA-Omni2-Bilingual |
|
| 65 |
+
| --------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- |
|
| 66 |
+
| 🤗 [LLaMA-Omni2-0.5B](https://huggingface.co/ICTNLP/LLaMA-Omni2-0.5B) | 🤗 [LLaMA-Omni2-0.5B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-0.5B-Bilingual) |
|
| 67 |
+
| 🤗 [LLaMA-Omni2-1.5B](https://huggingface.co/ICTNLP/LLaMA-Omni2-1.5B) | 🤗 [LLaMA-Omni2-1.5B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-1.5B-Bilingual) |
|
| 68 |
+
| 🤗 [LLaMA-Omni2-3B](https://huggingface.co/ICTNLP/LLaMA-Omni2-3B) | 🤗 [LLaMA-Omni2-3B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-3B-Bilingual) |
|
| 69 |
+
| 🤗 [LLaMA-Omni2-7B](https://huggingface.co/ICTNLP/LLaMA-Omni2-7B) | 🤗 [LLaMA-Omni2-7B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-7B-Bilingual) |
|
| 70 |
+
| 🤗 [LLaMA-Omni2-14B](https://huggingface.co/ICTNLP/LLaMA-Omni2-14B) | 🤗 [LLaMA-Omni2-14B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-14B-Bilingual) |
|
| 71 |
+
| - | 🤗 [LLaMA-Omni2-32B-Bilingual](https://huggingface.co/ICTNLP/LLaMA-Omni2-32B-Bilingual) |
|
| 72 |
+
|
| 73 |
+
## Gradio Demo
|
| 74 |
+
|
| 75 |
+
1. Launch a controller.
|
| 76 |
+
|
| 77 |
+
```shell
|
| 78 |
+
python -m llama_omni2.serve.controller --host 0.0.0.0 --port 10000
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
2. Launch a gradio web server.
|
| 82 |
+
|
| 83 |
+
```shell
|
| 84 |
+
python -m llama_omni2.serve.gradio_web_server --controller http://localhost:10000 --port 8000 --vocoder-dir models/cosy2_decoder
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
3. Launch a model worker.
|
| 88 |
+
|
| 89 |
+
```shell
|
| 90 |
+
python -m llama_omni2.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path models/$model_name --model-name $model_name
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
4. Visit [http://localhost:8000/](http://localhost:8000/) and interact with LLaMA-Omni2!
|
| 94 |
+
|
| 95 |
+
## Local Inference
|
| 96 |
+
|
| 97 |
+
```shell
|
| 98 |
+
output_dir=examples/$model_name
|
| 99 |
+
mkdir -p $output_dir
|
| 100 |
+
|
| 101 |
+
python llama_omni2/inference/run_llama_omni2.py \
|
| 102 |
+
--model_path models/$model_name \
|
| 103 |
+
--question_file examples/questions.json \
|
| 104 |
+
--answer_file $output_dir/answers.jsonl \
|
| 105 |
+
--temperature 0 \
|
| 106 |
+
--s2s
|
| 107 |
+
|
| 108 |
+
python llama_omni2/inference/run_cosy2_decoder.py \
|
| 109 |
+
--input-path $output_dir/answers.jsonl \
|
| 110 |
+
--output-dir $output_dir/wav \
|
| 111 |
+
--lang en
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## LICENSE
|
| 115 |
+
|
| 116 |
+
Our code is released under the Apache-2.0 License. Our model is intended for academic research purposes only and may **NOT** be used for commercial purposes.
|
| 117 |
+
|
| 118 |
+
You are free to use, modify, and distribute this model in academic settings, provided that the following conditions are met:
|
| 119 |
+
|
| 120 |
+
- **Non-commercial use**: The model may not be used for any commercial purposes.
|
| 121 |
+
- **Citation**: If you use this model in your research, please cite the original work.
|
| 122 |
+
|
| 123 |
+
### Commercial Use Restriction
|
| 124 |
+
|
| 125 |
+
For any commercial use inquiries or to obtain a commercial license, please contact `fengyang@ict.ac.cn`.
|
| 126 |
+
|
| 127 |
+
## Acknowledgements
|
| 128 |
+
|
| 129 |
+
- [CosyVoice 2](https://github.com/FunAudioLLM/CosyVoice): We use the pretrained speech tokenizer, flow-matching model and vocoder of CosyVoice 2.
|
| 130 |
+
- [SLAM-LLM](https://github.com/X-LANCE/SLAM-LLM): We borrow some code about speech encoder and speech adaptor.
|
| 131 |
+
|
| 132 |
+
## Citation
|
| 133 |
+
|
| 134 |
+
If you have any questions, please feel free to submit an issue or contact `fangqingkai21b@ict.ac.cn`.
|
| 135 |
+
|
| 136 |
+
If our work is useful for you, please cite as:
|
| 137 |
+
|
| 138 |
+
```
|
| 139 |
+
@inproceedings{
|
| 140 |
+
fang2025llamaomni2,
|
| 141 |
+
title={{LL}a{MA}-{O}mni 2: LLM-based Real-time Spoken Chatbot with Autoregressive Streaming Speech Synthesis},
|
| 142 |
+
author={Fang, Qingkai and Zhou, Yan and Guo, Shoutao and Zhang, Shaolei and Feng, Yang},
|
| 143 |
+
booktitle = {Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics},
|
| 144 |
+
year={2025}
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
@inproceedings{
|
| 148 |
+
fang2025llamaomni,
|
| 149 |
+
title={{LL}a{MA}-{O}mni: Seamless Speech Interaction with Large Language Models},
|
| 150 |
+
author={Qingkai Fang and Shoutao Guo and Yan Zhou and Zhengrui Ma and Shaolei Zhang and Yang Feng},
|
| 151 |
+
booktitle={The Thirteenth International Conference on Learning Representations},
|
| 152 |
+
year={2025},
|
| 153 |
+
url={https://openreview.net/forum?id=PYmrUQmMEw}
|
| 154 |
+
}
|
| 155 |
+
```
|
LLaMA-Omni2-3B/added_tokens.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"</tool_call>": 151658,
|
| 3 |
+
"<speech>": 151665,
|
| 4 |
+
"<tool_call>": 151657,
|
| 5 |
+
"<|box_end|>": 151649,
|
| 6 |
+
"<|box_start|>": 151648,
|
| 7 |
+
"<|endoftext|>": 151643,
|
| 8 |
+
"<|file_sep|>": 151664,
|
| 9 |
+
"<|fim_middle|>": 151660,
|
| 10 |
+
"<|fim_pad|>": 151662,
|
| 11 |
+
"<|fim_prefix|>": 151659,
|
| 12 |
+
"<|fim_suffix|>": 151661,
|
| 13 |
+
"<|im_end|>": 151645,
|
| 14 |
+
"<|im_start|>": 151644,
|
| 15 |
+
"<|image_pad|>": 151655,
|
| 16 |
+
"<|object_ref_end|>": 151647,
|
| 17 |
+
"<|object_ref_start|>": 151646,
|
| 18 |
+
"<|quad_end|>": 151651,
|
| 19 |
+
"<|quad_start|>": 151650,
|
| 20 |
+
"<|repo_name|>": 151663,
|
| 21 |
+
"<|video_pad|>": 151656,
|
| 22 |
+
"<|vision_end|>": 151653,
|
| 23 |
+
"<|vision_pad|>": 151654,
|
| 24 |
+
"<|vision_start|>": 151652
|
| 25 |
+
}
|
LLaMA-Omni2-3B/config.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "LLaMA-Omni2-3B",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"Omni2Speech2SQwen2ForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"eos_token_id": 151645,
|
| 9 |
+
"hidden_act": "silu",
|
| 10 |
+
"hidden_size": 2048,
|
| 11 |
+
"initializer_range": 0.02,
|
| 12 |
+
"intermediate_size": 11008,
|
| 13 |
+
"max_position_embeddings": 32768,
|
| 14 |
+
"max_window_layers": 70,
|
| 15 |
+
"model_type": "omni2_speech2s_qwen2",
|
| 16 |
+
"num_attention_heads": 16,
|
| 17 |
+
"num_hidden_layers": 36,
|
| 18 |
+
"num_key_value_heads": 2,
|
| 19 |
+
"rms_norm_eps": 1e-06,
|
| 20 |
+
"rope_theta": 1000000.0,
|
| 21 |
+
"sliding_window": null,
|
| 22 |
+
"speech_encoder": "models/speech_encoder/large-v3.pt",
|
| 23 |
+
"speech_encoder_ds_rate": 5,
|
| 24 |
+
"speech_encoder_hidden_size": 1280,
|
| 25 |
+
"speech_encoder_type": "whisper",
|
| 26 |
+
"speech_generator": {
|
| 27 |
+
"architectures": [
|
| 28 |
+
"Qwen2ForCausalLM"
|
| 29 |
+
],
|
| 30 |
+
"attention_dropout": 0.0,
|
| 31 |
+
"bos_token_id": 151643,
|
| 32 |
+
"eos_token_id": 151643,
|
| 33 |
+
"hidden_act": "silu",
|
| 34 |
+
"hidden_size": 896,
|
| 35 |
+
"initializer_range": 0.02,
|
| 36 |
+
"intermediate_size": 4864,
|
| 37 |
+
"max_position_embeddings": 32768,
|
| 38 |
+
"max_window_layers": 24,
|
| 39 |
+
"model_type": "qwen2",
|
| 40 |
+
"num_attention_heads": 14,
|
| 41 |
+
"num_hidden_layers": 24,
|
| 42 |
+
"num_key_value_heads": 2,
|
| 43 |
+
"rms_norm_eps": 1e-06,
|
| 44 |
+
"rope_theta": 1000000.0,
|
| 45 |
+
"sliding_window": null,
|
| 46 |
+
"tie_word_embeddings": true,
|
| 47 |
+
"torch_dtype": "bfloat16",
|
| 48 |
+
"transformers_version": "4.43.4",
|
| 49 |
+
"use_cache": true,
|
| 50 |
+
"use_mrope": false,
|
| 51 |
+
"use_sliding_window": false,
|
| 52 |
+
"vocab_size": 158227
|
| 53 |
+
},
|
| 54 |
+
"speech_projector_type": "linear",
|
| 55 |
+
"stream_params": "(3,10)",
|
| 56 |
+
"tie_word_embeddings": true,
|
| 57 |
+
"tokenizer_model_max_length": 4096,
|
| 58 |
+
"tokenizer_padding_side": "right",
|
| 59 |
+
"torch_dtype": "bfloat16",
|
| 60 |
+
"transformers_version": "4.43.4",
|
| 61 |
+
"unit_vocab_size": 6561,
|
| 62 |
+
"use_cache": true,
|
| 63 |
+
"use_sliding_window": false,
|
| 64 |
+
"vocab_size": 151936
|
| 65 |
+
}
|
LLaMA-Omni2-3B/generation_config.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_implementation": "flash_attention_2",
|
| 3 |
+
"bos_token_id": 151643,
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": [
|
| 6 |
+
151645,
|
| 7 |
+
151643
|
| 8 |
+
],
|
| 9 |
+
"pad_token_id": 151643,
|
| 10 |
+
"repetition_penalty": 1.05,
|
| 11 |
+
"temperature": 0.7,
|
| 12 |
+
"top_k": 20,
|
| 13 |
+
"top_p": 0.8,
|
| 14 |
+
"transformers_version": "4.43.4"
|
| 15 |
+
}
|
LLaMA-Omni2-3B/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LLaMA-Omni2-3B/model-00001-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc4b7fda5d470353f675e0410724af00479eb09b3c81c5648bad35ac97904665
|
| 3 |
+
size 4957560304
|
LLaMA-Omni2-3B/model-00002-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60dd465c6e6ceac492af19fbdbe11dd9fb4104b2a71c3825fba38a8a0427ed94
|
| 3 |
+
size 4455567096
|
LLaMA-Omni2-3B/model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LLaMA-Omni2-3B/special_tokens_map.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|im_start|>",
|
| 4 |
+
"<|im_end|>",
|
| 5 |
+
"<|object_ref_start|>",
|
| 6 |
+
"<|object_ref_end|>",
|
| 7 |
+
"<|box_start|>",
|
| 8 |
+
"<|box_end|>",
|
| 9 |
+
"<|quad_start|>",
|
| 10 |
+
"<|quad_end|>",
|
| 11 |
+
"<|vision_start|>",
|
| 12 |
+
"<|vision_end|>",
|
| 13 |
+
"<|vision_pad|>",
|
| 14 |
+
"<|image_pad|>",
|
| 15 |
+
"<|video_pad|>"
|
| 16 |
+
],
|
| 17 |
+
"eos_token": {
|
| 18 |
+
"content": "<|im_end|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
},
|
| 24 |
+
"pad_token": "<|endoftext|>"
|
| 25 |
+
}
|
LLaMA-Omni2-3B/tokenizer_config.json
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<speech>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": true,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
}
|
| 189 |
+
},
|
| 190 |
+
"additional_special_tokens": [
|
| 191 |
+
"<|im_start|>",
|
| 192 |
+
"<|im_end|>",
|
| 193 |
+
"<|object_ref_start|>",
|
| 194 |
+
"<|object_ref_end|>",
|
| 195 |
+
"<|box_start|>",
|
| 196 |
+
"<|box_end|>",
|
| 197 |
+
"<|quad_start|>",
|
| 198 |
+
"<|quad_end|>",
|
| 199 |
+
"<|vision_start|>",
|
| 200 |
+
"<|vision_end|>",
|
| 201 |
+
"<|vision_pad|>",
|
| 202 |
+
"<|image_pad|>",
|
| 203 |
+
"<|video_pad|>"
|
| 204 |
+
],
|
| 205 |
+
"bos_token": null,
|
| 206 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
| 207 |
+
"clean_up_tokenization_spaces": false,
|
| 208 |
+
"eos_token": "<|im_end|>",
|
| 209 |
+
"errors": "replace",
|
| 210 |
+
"model_max_length": 4096,
|
| 211 |
+
"pad_token": "<|endoftext|>",
|
| 212 |
+
"padding_side": "right",
|
| 213 |
+
"split_special_tokens": false,
|
| 214 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 215 |
+
"unk_token": null
|
| 216 |
+
}
|
LLaMA-Omni2-3B/tts_tokenizer/added_tokens.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LLaMA-Omni2-3B/tts_tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LLaMA-Omni2-3B/tts_tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|im_start|>",
|
| 4 |
+
"<|im_end|>",
|
| 5 |
+
"<|object_ref_start|>",
|
| 6 |
+
"<|object_ref_end|>",
|
| 7 |
+
"<|box_start|>",
|
| 8 |
+
"<|box_end|>",
|
| 9 |
+
"<|quad_start|>",
|
| 10 |
+
"<|quad_end|>",
|
| 11 |
+
"<|vision_start|>",
|
| 12 |
+
"<|vision_end|>",
|
| 13 |
+
"<|vision_pad|>",
|
| 14 |
+
"<|image_pad|>",
|
| 15 |
+
"<|video_pad|>"
|
| 16 |
+
],
|
| 17 |
+
"eos_token": {
|
| 18 |
+
"content": "<|endoftext|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
},
|
| 24 |
+
"pad_token": "<|endoftext|>"
|
| 25 |
+
}
|
LLaMA-Omni2-3B/tts_tokenizer/tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LLaMA-Omni2-3B/tts_tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LLaMA-Omni2-3B/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎙️🤖 Goodspace Voice Agent: LLM-based Real-time Spoken Chatbot with Autoregressive Streaming Speech Synthesis
|
| 2 |
+
|
| 3 |
+
> **Powered by advanced speech-language models and streaming synthesis technology**
|
| 4 |
+
|
| 5 |
+
[](https://github.com/goodspace/voice-agent)
|
| 6 |
+
[](https://huggingface.co/collections/goodspace/voice-agent)
|
| 7 |
+
[](https://huggingface.co/datasets/goodspace/speech-conversations)
|
| 8 |
+
|
| 9 |
+
Goodspace Voice Agent is a cutting-edge series of speech-language models built on the Qwen2.5-0.5B/1.5B/3B/7B/14B/32B-Instruct models. It can generate both text and speech responses simultaneously, enabling high-quality and low-latency speech interaction. With the streaming autoregressive speech decoder, Goodspace Voice Agent achieves exceptional speech quality and natural conversation flow.
|
| 10 |
+
|
| 11 |
+
<div align="center"><img src="images/llama-omni2.png" width="75%"/></div>
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## 🔥 News
|
| 15 |
+
|
| 16 |
+
- Goodspace Voice Agent - Advanced real-time voice interaction system now available!
|
| 17 |
+
|
| 18 |
+
## Install
|
| 19 |
+
|
| 20 |
+
1. Clone this repository.
|
| 21 |
+
|
| 22 |
+
```shell
|
| 23 |
+
git clone https://github.com/goodspace/voice-agent
|
| 24 |
+
cd voice-agent
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
2. Install packages.
|
| 28 |
+
|
| 29 |
+
```shell
|
| 30 |
+
conda create -n goodspace-voice python=3.10
|
| 31 |
+
conda activate goodspace-voice
|
| 32 |
+
pip install -e .
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Quick Start
|
| 36 |
+
|
| 37 |
+
1. Download the `Whisper-large-v3` model.
|
| 38 |
+
|
| 39 |
+
```shell
|
| 40 |
+
import whisper
|
| 41 |
+
model = whisper.load_model("large-v3", download_root="models/speech_encoder/")
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
2. Download the flow-matching model and vocoder of `CosyVoice 2`.
|
| 45 |
+
|
| 46 |
+
```shell
|
| 47 |
+
huggingface-cli download --resume-download goodspace/cosy2_decoder --local-dir models/cosy2_decoder
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
> [!Tip]
|
| 51 |
+
> If you’re experiencing unstable connections to Hugging Face from within China, you can try setting the following in your command line:
|
| 52 |
+
>
|
| 53 |
+
> ```shell
|
| 54 |
+
> export HF_ENDPOINT=https://hf-mirror.com
|
| 55 |
+
> ```
|
| 56 |
+
|
| 57 |
+
3. Download the Goodspace Voice Agent models from Hugging Face. `GoodspaceVoice-0.5B/1.5B/3B/7B/14B` support **English only**, while `GoodspaceVoice-0.5B/1.5B/3B/7B/14B/32B-Bilingual` support **both English and Chinese**.
|
| 58 |
+
|
| 59 |
+
```shell
|
| 60 |
+
model_name=GoodspaceVoice-7B-Bilingual
|
| 61 |
+
huggingface-cli download --resume-download goodspace/$model_name --local-dir models/$model_name
|
| 62 |
+
```
|
| 63 |
+
## Gradio Demo
|
| 64 |
+
|
| 65 |
+
1. Launch a controller.
|
| 66 |
+
|
| 67 |
+
```shell
|
| 68 |
+
python -m goodspace_voice.serve.controller --host 0.0.0.0 --port 10000
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
2. Launch a gradio web server.
|
| 72 |
+
|
| 73 |
+
```shell
|
| 74 |
+
python -m goodspace_voice.serve.gradio_web_server --controller http://localhost:10000 --port 8000 --vocoder-dir models/cosy2_decoder
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
3. Launch a model worker.
|
| 78 |
+
|
| 79 |
+
```shell
|
| 80 |
+
python -m goodspace_voice.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path models/$model_name --model-name $model_name
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
4. Visit [http://localhost:8000/](http://localhost:8000/) and interact with GoodspaceVoice!
|
| 84 |
+
|
| 85 |
+
## Local Inference
|
| 86 |
+
|
| 87 |
+
```shell
|
| 88 |
+
output_dir=examples/$model_name
|
| 89 |
+
mkdir -p $output_dir
|
| 90 |
+
|
| 91 |
+
python goodspace_voice/inference/run_goodspace_voice.py \
|
| 92 |
+
--model_path models/$model_name \
|
| 93 |
+
--question_file examples/questions.json \
|
| 94 |
+
--answer_file $output_dir/answers.jsonl \
|
| 95 |
+
--temperature 0 \
|
| 96 |
+
--s2s
|
| 97 |
+
|
| 98 |
+
python goodspace_voice/inference/run_cosy2_decoder.py \
|
| 99 |
+
--input-path $output_dir/answers.jsonl \
|
| 100 |
+
--output-dir $output_dir/wav \
|
| 101 |
+
--lang en
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## LICENSE
|
| 105 |
+
|
| 106 |
+
The Goodspace Voice Agent is released under the Apache-2.0 License.
|
| 107 |
+
|
| 108 |
+
### Commercial Use
|
| 109 |
+
|
| 110 |
+
For commercial use inquiries or licensing information, please contact the Goodspace team.
|
| 111 |
+
|
| 112 |
+
## Acknowledgements
|
| 113 |
+
|
| 114 |
+
- [CosyVoice 2](https://github.com/FunAudioLLM/CosyVoice): We use the pretrained speech tokenizer, flow-matching model and vocoder of CosyVoice 2.
|
| 115 |
+
- [SLAM-LLM](https://github.com/X-LANCE/SLAM-LLM): We borrow some code about speech encoder and speech adaptor.
|
| 116 |
+
- Based on the research work from LLaMA-Omni2 paper.
|
| 117 |
+
|
| 118 |
+
## Support
|
| 119 |
+
|
| 120 |
+
If you have any questions or issues, please feel free to submit an issue on our GitHub repository.
|
| 121 |
+
|
| 122 |
+
## Contributing
|
| 123 |
+
|
| 124 |
+
We welcome contributions! Please see our contributing guidelines for more information.
|
SETUP_GUIDE.md
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLaMA-Omni2 Voice Assistant Setup Guide
|
| 2 |
+
|
| 3 |
+
This guide provides comprehensive instructions for reproducing the exact environment and setup for the LLaMA-Omni2 voice assistant with CosyVoice2 integration.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
- Ubuntu/Linux system with CUDA-capable GPU
|
| 8 |
+
- CUDA 12.1 or higher installed
|
| 9 |
+
- Miniconda or Anaconda installed
|
| 10 |
+
- At least 16GB RAM and 20GB free disk space
|
| 11 |
+
- Python 3.10
|
| 12 |
+
|
| 13 |
+
## Environment Setup Options
|
| 14 |
+
|
| 15 |
+
### Option 1: Using Conda Environment File (Recommended)
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
# Create environment from comprehensive yml file
|
| 19 |
+
conda env create -f environment-comprehensive.yml
|
| 20 |
+
|
| 21 |
+
# Activate the environment
|
| 22 |
+
conda activate gsva-python310
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### Option 2: Using Frozen Requirements
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
# Create a new conda environment
|
| 29 |
+
conda create -n gsva-python310 python=3.10 -y
|
| 30 |
+
conda activate gsva-python310
|
| 31 |
+
|
| 32 |
+
# Install from frozen requirements
|
| 33 |
+
pip install -r requirements-frozen-new.txt
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### Option 3: Manual Setup Using Script
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
# Run the complete setup script
|
| 40 |
+
bash script.sh
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Detailed Manual Setup
|
| 44 |
+
|
| 45 |
+
### 1. Create and Activate Conda Environment
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
source /home/azureuser/miniconda3/etc/profile.d/conda.sh
|
| 49 |
+
conda create -n gsva-python310 python=3.10 -y
|
| 50 |
+
conda activate gsva-python310
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### 2. Install Basic Dependencies
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
pip install Cython numpy==1.26.4
|
| 57 |
+
pip install packaging wheel setuptools==69.5.1
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### 3. Install the Package
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
# Install in development mode
|
| 64 |
+
pip install -e .
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### 4. Install Core Dependencies
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
# Essential packages
|
| 71 |
+
pip install huggingface_hub==0.25.1
|
| 72 |
+
pip install uvicorn openai-whisper fastapi
|
| 73 |
+
pip install hf_transfer ninja
|
| 74 |
+
|
| 75 |
+
# Gradio for web interface
|
| 76 |
+
pip install gradio==5.3.0 gradio_client==1.4.2
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### 5. Setup CUDA Environment
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
# Link CUDA installation
|
| 83 |
+
sudo rm -rf /usr/local/cuda
|
| 84 |
+
sudo ln -s /usr/local/cuda-12.6 /usr/local/cuda
|
| 85 |
+
export PATH=/usr/local/cuda/bin:$PATH
|
| 86 |
+
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### 6. Install PyTorch with CUDA Support
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
### 7. Install Flash Attention
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
MAX_JOBS=4 pip install flash-attn --no-build-isolation
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
### 8. Install Transformers and Audio Libraries
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
# Specific version for LLaMA-Omni2 compatibility
|
| 105 |
+
pip install transformers==4.43.4
|
| 106 |
+
|
| 107 |
+
# Audio processing libraries
|
| 108 |
+
pip install matcha-tts --no-build-isolation
|
| 109 |
+
pip install git+https://github.com/FunAudioLLM/CosyVoice.git
|
| 110 |
+
|
| 111 |
+
# Additional dependencies
|
| 112 |
+
pip install conformer onnxruntime hyperpyyaml==1.2.2 ruamel.yaml
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## Model Downloads
|
| 116 |
+
|
| 117 |
+
### 1. Download LLaMA-Omni2 Model
|
| 118 |
+
|
| 119 |
+
```bash
|
| 120 |
+
mkdir -p models
|
| 121 |
+
huggingface-cli download ICTNLP/LLaMA-Omni2-3B --local-dir models/LLaMA-Omni2-3B
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### 2. Download CosyVoice2 Model
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
mkdir -p models/cosyvoice2
|
| 128 |
+
python -c "
|
| 129 |
+
from huggingface_hub import snapshot_download
|
| 130 |
+
import os
|
| 131 |
+
os.makedirs('models/cosyvoice2', exist_ok=True)
|
| 132 |
+
snapshot_download(
|
| 133 |
+
repo_id='FunAudioLLM/CosyVoice2-0.5B',
|
| 134 |
+
local_dir='models/cosyvoice2',
|
| 135 |
+
local_dir_use_symlinks=False
|
| 136 |
+
)
|
| 137 |
+
"
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
### 3. Fix CosyVoice Configuration
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
# Create backup
|
| 144 |
+
cp models/cosyvoice2/cosyvoice2.yaml models/cosyvoice2/cosyvoice2.yaml.backup
|
| 145 |
+
|
| 146 |
+
# Copy to expected filename
|
| 147 |
+
cp models/cosyvoice2/cosyvoice2.yaml models/cosyvoice2/cosyvoice.yaml
|
| 148 |
+
|
| 149 |
+
# Remove problematic parameter
|
| 150 |
+
grep -v "mix_ratio" models/cosyvoice2/cosyvoice.yaml > models/cosyvoice2/cosyvoice_fixed.yaml
|
| 151 |
+
mv models/cosyvoice2/cosyvoice_fixed.yaml models/cosyvoice2/cosyvoice.yaml
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
## Running the Services
|
| 155 |
+
|
| 156 |
+
### 1. Start Controller
|
| 157 |
+
|
| 158 |
+
```bash
|
| 159 |
+
nohup python -m llama_omni2.serve.controller \
|
| 160 |
+
--host 0.0.0.0 \
|
| 161 |
+
--port 10000 > controller.log 2>&1 &
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
### 2. Start Model Worker
|
| 165 |
+
|
| 166 |
+
```bash
|
| 167 |
+
nohup python -m llama_omni2.serve.model_worker \
|
| 168 |
+
--host 0.0.0.0 \
|
| 169 |
+
--controller http://localhost:10000 \
|
| 170 |
+
--port 40000 \
|
| 171 |
+
--worker http://localhost:40000 \
|
| 172 |
+
--model-path models/LLaMA-Omni2-3B \
|
| 173 |
+
--model-name LLaMA-Omni2-3B > worker.log 2>&1 &
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
### 3. Start Gradio Web Server
|
| 177 |
+
|
| 178 |
+
With CosyVoice2 vocoder:
|
| 179 |
+
```bash
|
| 180 |
+
python -m llama_omni2.serve.gradio_web_server \
|
| 181 |
+
--controller http://localhost:10000 \
|
| 182 |
+
--port 8000 \
|
| 183 |
+
--vocoder-dir models/cosyvoice2
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
Without vocoder (fallback):
|
| 187 |
+
```bash
|
| 188 |
+
python -m llama_omni2.serve.gradio_web_server \
|
| 189 |
+
--controller http://localhost:10000 \
|
| 190 |
+
--port 8000
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
## Monitoring Services
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
# Check controller logs
|
| 197 |
+
tail -f controller.log
|
| 198 |
+
|
| 199 |
+
# Check model worker logs
|
| 200 |
+
tail -f worker.log
|
| 201 |
+
|
| 202 |
+
# Access web UI
|
| 203 |
+
# Open browser at http://localhost:8000
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
## Troubleshooting
|
| 207 |
+
|
| 208 |
+
### Common Issues
|
| 209 |
+
|
| 210 |
+
1. **CUDA not found**: Ensure CUDA paths are exported correctly
|
| 211 |
+
2. **Flash attention build fails**: Use `MAX_JOBS=4` to limit parallel compilation
|
| 212 |
+
3. **CosyVoice mix_ratio error**: Follow the configuration fix steps above
|
| 213 |
+
4. **Port already in use**: Kill existing processes or use different ports
|
| 214 |
+
|
| 215 |
+
### Killing Services
|
| 216 |
+
|
| 217 |
+
```bash
|
| 218 |
+
# Find and kill Python processes
|
| 219 |
+
ps aux | grep python | grep -E "(controller|model_worker|gradio_web_server)" | awk '{print $2}' | xargs -r kill
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
## Project Structure
|
| 223 |
+
|
| 224 |
+
```
|
| 225 |
+
voiceagents/
|
| 226 |
+
├── llama_omni2/ # Main application code
|
| 227 |
+
├── cosyvoice/ # CosyVoice integration
|
| 228 |
+
├── models/ # Downloaded models
|
| 229 |
+
│ ├── LLaMA-Omni2-3B/
|
| 230 |
+
│ └── cosyvoice2/
|
| 231 |
+
├── examples/ # Sample audio files
|
| 232 |
+
├── script.sh # Setup script
|
| 233 |
+
├── pyproject.toml # Project configuration
|
| 234 |
+
├── requirements-frozen-new.txt # Frozen dependencies
|
| 235 |
+
├── environment-comprehensive.yml # Conda environment
|
| 236 |
+
└── SETUP_GUIDE.md # This file
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
## Environment Variables
|
| 240 |
+
|
| 241 |
+
Set these in your `.bashrc` or `.zshrc`:
|
| 242 |
+
|
| 243 |
+
```bash
|
| 244 |
+
export PATH=/usr/local/cuda/bin:$PATH
|
| 245 |
+
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
| 246 |
+
export HF_HUB_ENABLE_HF_TRANSFER=1
|
| 247 |
+
export HF_HOME=~/.cache/huggingface
|
| 248 |
+
export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0"
|
| 249 |
+
export MAX_JOBS=4
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
## Version Information
|
| 253 |
+
|
| 254 |
+
- Python: 3.10
|
| 255 |
+
- PyTorch: 2.3.1
|
| 256 |
+
- Transformers: 4.43.4
|
| 257 |
+
- Gradio: 5.3.0
|
| 258 |
+
- CUDA: 12.1+
|
| 259 |
+
- CosyVoice2: 0.5B model
|
| 260 |
+
|
| 261 |
+
## Additional Notes
|
| 262 |
+
|
| 263 |
+
- The setup has been tested on Ubuntu with NVIDIA GPUs
|
| 264 |
+
- Ensure sufficient GPU memory (8GB+ recommended)
|
| 265 |
+
- For production deployment, consider using systemd services
|
| 266 |
+
- Regular backups of models and configurations are recommended
|
| 267 |
+
|
| 268 |
+
## Support
|
| 269 |
+
|
| 270 |
+
For issues or questions:
|
| 271 |
+
- Check the logs in controller.log, worker.log
|
| 272 |
+
- Ensure all dependencies are correctly installed
|
| 273 |
+
- Verify CUDA is properly configured
|
| 274 |
+
- Review the COSYVOICE2_CHANGES.md for model-specific details
|
controller.log.2025-08-16
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-08-16 15:21:01 | INFO | controller | args: Namespace(host='0.0.0.0', port=10000, dispatch_method='shortest_queue')
|
| 2 |
+
2025-08-16 15:21:01 | INFO | controller | Init controller
|
| 3 |
+
2025-08-16 15:21:01 | ERROR | stderr | INFO: Started server process [32029]
|
| 4 |
+
2025-08-16 15:21:01 | ERROR | stderr | INFO: Waiting for application startup.
|
| 5 |
+
2025-08-16 15:21:01 | ERROR | stderr | INFO: Application startup complete.
|
| 6 |
+
2025-08-16 15:21:01 | ERROR | stderr | INFO: Uvicorn running on http://0.0.0.0:10000 (Press CTRL+C to quit)
|
cosyvoice/__init__.py
ADDED
|
File without changes
|
cosyvoice/bin/average_model.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
| 2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import argparse
|
| 18 |
+
import glob
|
| 19 |
+
|
| 20 |
+
import yaml
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_args():
|
| 25 |
+
parser = argparse.ArgumentParser(description='average model')
|
| 26 |
+
parser.add_argument('--dst_model', required=True, help='averaged model')
|
| 27 |
+
parser.add_argument('--src_path',
|
| 28 |
+
required=True,
|
| 29 |
+
help='src model path for average')
|
| 30 |
+
parser.add_argument('--val_best',
|
| 31 |
+
action="store_true",
|
| 32 |
+
help='averaged model')
|
| 33 |
+
parser.add_argument('--num',
|
| 34 |
+
default=5,
|
| 35 |
+
type=int,
|
| 36 |
+
help='nums for averaged model')
|
| 37 |
+
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
print(args)
|
| 40 |
+
return args
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
args = get_args()
|
| 45 |
+
val_scores = []
|
| 46 |
+
if args.val_best:
|
| 47 |
+
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
| 48 |
+
yamls = [
|
| 49 |
+
f for f in yamls
|
| 50 |
+
if not (os.path.basename(f).startswith('train')
|
| 51 |
+
or os.path.basename(f).startswith('init'))
|
| 52 |
+
]
|
| 53 |
+
for y in yamls:
|
| 54 |
+
with open(y, 'r') as f:
|
| 55 |
+
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
| 56 |
+
loss = float(dic_yaml['loss_dict']['loss'])
|
| 57 |
+
epoch = int(dic_yaml['epoch'])
|
| 58 |
+
step = int(dic_yaml['step'])
|
| 59 |
+
tag = dic_yaml['tag']
|
| 60 |
+
val_scores += [[epoch, step, loss, tag]]
|
| 61 |
+
sorted_val_scores = sorted(val_scores,
|
| 62 |
+
key=lambda x: x[2],
|
| 63 |
+
reverse=False)
|
| 64 |
+
print("best val (epoch, step, loss, tag) = " +
|
| 65 |
+
str(sorted_val_scores[:args.num]))
|
| 66 |
+
path_list = [
|
| 67 |
+
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
| 68 |
+
for score in sorted_val_scores[:args.num]
|
| 69 |
+
]
|
| 70 |
+
print(path_list)
|
| 71 |
+
avg = {}
|
| 72 |
+
num = args.num
|
| 73 |
+
assert num == len(path_list)
|
| 74 |
+
for path in path_list:
|
| 75 |
+
print('Processing {}'.format(path))
|
| 76 |
+
states = torch.load(path, map_location=torch.device('cpu'))
|
| 77 |
+
for k in states.keys():
|
| 78 |
+
if k not in avg.keys():
|
| 79 |
+
avg[k] = states[k].clone()
|
| 80 |
+
else:
|
| 81 |
+
avg[k] += states[k]
|
| 82 |
+
# average
|
| 83 |
+
for k in avg.keys():
|
| 84 |
+
if avg[k] is not None:
|
| 85 |
+
# pytorch 1.6 use true_divide instead of /=
|
| 86 |
+
avg[k] = torch.true_divide(avg[k], num)
|
| 87 |
+
print('Saving to {}'.format(args.dst_model))
|
| 88 |
+
torch.save(avg, args.dst_model)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == '__main__':
|
| 92 |
+
main()
|
cosyvoice/bin/export_jit.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import torch
|
| 23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
| 25 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
| 26 |
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_args():
|
| 30 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
| 31 |
+
parser.add_argument('--model_dir',
|
| 32 |
+
type=str,
|
| 33 |
+
default='pretrained_models/CosyVoice-300M',
|
| 34 |
+
help='local path')
|
| 35 |
+
args = parser.parse_args()
|
| 36 |
+
print(args)
|
| 37 |
+
return args
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main():
|
| 41 |
+
args = get_args()
|
| 42 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 43 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 44 |
+
|
| 45 |
+
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
| 46 |
+
torch._C._jit_set_profiling_mode(False)
|
| 47 |
+
torch._C._jit_set_profiling_executor(False)
|
| 48 |
+
|
| 49 |
+
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
| 50 |
+
|
| 51 |
+
# 1. export llm text_encoder
|
| 52 |
+
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
| 53 |
+
script = torch.jit.script(llm_text_encoder)
|
| 54 |
+
script = torch.jit.freeze(script)
|
| 55 |
+
script = torch.jit.optimize_for_inference(script)
|
| 56 |
+
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
| 57 |
+
|
| 58 |
+
# 2. export llm llm
|
| 59 |
+
llm_llm = cosyvoice.model.llm.llm.half()
|
| 60 |
+
script = torch.jit.script(llm_llm)
|
| 61 |
+
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
|
| 62 |
+
script = torch.jit.optimize_for_inference(script)
|
| 63 |
+
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
| 64 |
+
|
| 65 |
+
# 3. export flow encoder
|
| 66 |
+
flow_encoder = cosyvoice.model.flow.encoder
|
| 67 |
+
script = torch.jit.script(flow_encoder)
|
| 68 |
+
script = torch.jit.freeze(script)
|
| 69 |
+
script = torch.jit.optimize_for_inference(script)
|
| 70 |
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == '__main__':
|
| 74 |
+
main()
|
cosyvoice/bin/export_onnx.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
|
| 2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import print_function
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import logging
|
| 20 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import onnxruntime
|
| 24 |
+
import random
|
| 25 |
+
import torch
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 28 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
| 29 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
| 30 |
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
| 34 |
+
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 35 |
+
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
| 36 |
+
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 37 |
+
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
| 38 |
+
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
| 39 |
+
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
| 40 |
+
return x, mask, mu, t, spks, cond
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_args():
|
| 44 |
+
parser = argparse.ArgumentParser(description='export your model for deployment')
|
| 45 |
+
parser.add_argument('--model_dir',
|
| 46 |
+
type=str,
|
| 47 |
+
default='pretrained_models/CosyVoice-300M',
|
| 48 |
+
help='local path')
|
| 49 |
+
args = parser.parse_args()
|
| 50 |
+
print(args)
|
| 51 |
+
return args
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
|
| 55 |
+
args = get_args()
|
| 56 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 57 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 58 |
+
|
| 59 |
+
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
| 60 |
+
|
| 61 |
+
# 1. export flow decoder estimator
|
| 62 |
+
estimator = cosyvoice.model.flow.decoder.estimator
|
| 63 |
+
|
| 64 |
+
device = cosyvoice.model.device
|
| 65 |
+
batch_size, seq_len = 1, 256
|
| 66 |
+
out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
|
| 67 |
+
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
| 68 |
+
torch.onnx.export(
|
| 69 |
+
estimator,
|
| 70 |
+
(x, mask, mu, t, spks, cond),
|
| 71 |
+
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
| 72 |
+
export_params=True,
|
| 73 |
+
opset_version=18,
|
| 74 |
+
do_constant_folding=True,
|
| 75 |
+
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
| 76 |
+
output_names=['estimator_out'],
|
| 77 |
+
dynamic_axes={
|
| 78 |
+
'x': {0: 'batch_size', 2: 'seq_len'},
|
| 79 |
+
'mask': {0: 'batch_size', 2: 'seq_len'},
|
| 80 |
+
'mu': {0: 'batch_size', 2: 'seq_len'},
|
| 81 |
+
'cond': {0: 'batch_size', 2: 'seq_len'},
|
| 82 |
+
't': {0: 'batch_size'},
|
| 83 |
+
'spks': {0: 'batch_size'},
|
| 84 |
+
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
|
| 85 |
+
}
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# 2. test computation consistency
|
| 89 |
+
option = onnxruntime.SessionOptions()
|
| 90 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 91 |
+
option.intra_op_num_threads = 1
|
| 92 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
| 93 |
+
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
| 94 |
+
sess_options=option, providers=providers)
|
| 95 |
+
|
| 96 |
+
for _ in tqdm(range(10)):
|
| 97 |
+
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
|
| 98 |
+
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
| 99 |
+
ort_inputs = {
|
| 100 |
+
'x': x.cpu().numpy(),
|
| 101 |
+
'mask': mask.cpu().numpy(),
|
| 102 |
+
'mu': mu.cpu().numpy(),
|
| 103 |
+
't': t.cpu().numpy(),
|
| 104 |
+
'spks': spks.cpu().numpy(),
|
| 105 |
+
'cond': cond.cpu().numpy()
|
| 106 |
+
}
|
| 107 |
+
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
| 108 |
+
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
main()
|
cosyvoice/bin/export_trt.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
| 3 |
+
# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
|
| 4 |
+
# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
|
| 5 |
+
TRT_DIR=<YOUR_TRT_DIR>
|
| 6 |
+
MODEL_DIR=<COSYVOICE2_MODEL_DIR>
|
| 7 |
+
|
| 8 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
|
| 9 |
+
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
|
cosyvoice/bin/inference.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
import os
|
| 21 |
+
import torch
|
| 22 |
+
from torch.utils.data import DataLoader
|
| 23 |
+
import torchaudio
|
| 24 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
from cosyvoice.cli.model import CosyVoiceModel
|
| 27 |
+
from cosyvoice.dataset.dataset import Dataset
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_args():
|
| 31 |
+
parser = argparse.ArgumentParser(description='inference with your model')
|
| 32 |
+
parser.add_argument('--config', required=True, help='config file')
|
| 33 |
+
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
| 34 |
+
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
| 35 |
+
parser.add_argument('--tts_text', required=True, help='tts input file')
|
| 36 |
+
parser.add_argument('--llm_model', required=True, help='llm model file')
|
| 37 |
+
parser.add_argument('--flow_model', required=True, help='flow model file')
|
| 38 |
+
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
| 39 |
+
parser.add_argument('--gpu',
|
| 40 |
+
type=int,
|
| 41 |
+
default=-1,
|
| 42 |
+
help='gpu id for this rank, -1 for cpu')
|
| 43 |
+
parser.add_argument('--mode',
|
| 44 |
+
default='sft',
|
| 45 |
+
choices=['sft', 'zero_shot'],
|
| 46 |
+
help='inference mode')
|
| 47 |
+
parser.add_argument('--result_dir', required=True, help='asr result file')
|
| 48 |
+
args = parser.parse_args()
|
| 49 |
+
print(args)
|
| 50 |
+
return args
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
|
| 54 |
+
args = get_args()
|
| 55 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 56 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 57 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
| 58 |
+
|
| 59 |
+
# Init cosyvoice models from configs
|
| 60 |
+
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
| 61 |
+
device = torch.device('cuda' if use_cuda else 'cpu')
|
| 62 |
+
with open(args.config, 'r') as f:
|
| 63 |
+
configs = load_hyperpyyaml(f)
|
| 64 |
+
|
| 65 |
+
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
| 66 |
+
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
| 67 |
+
|
| 68 |
+
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
| 69 |
+
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
| 70 |
+
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
| 71 |
+
|
| 72 |
+
del configs
|
| 73 |
+
os.makedirs(args.result_dir, exist_ok=True)
|
| 74 |
+
fn = os.path.join(args.result_dir, 'wav.scp')
|
| 75 |
+
f = open(fn, 'w')
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
for _, batch in tqdm(enumerate(test_data_loader)):
|
| 78 |
+
utts = batch["utts"]
|
| 79 |
+
assert len(utts) == 1, "inference mode only support batchsize 1"
|
| 80 |
+
text_token = batch["text_token"].to(device)
|
| 81 |
+
text_token_len = batch["text_token_len"].to(device)
|
| 82 |
+
tts_index = batch["tts_index"]
|
| 83 |
+
tts_text_token = batch["tts_text_token"].to(device)
|
| 84 |
+
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
| 85 |
+
speech_token = batch["speech_token"].to(device)
|
| 86 |
+
speech_token_len = batch["speech_token_len"].to(device)
|
| 87 |
+
speech_feat = batch["speech_feat"].to(device)
|
| 88 |
+
speech_feat_len = batch["speech_feat_len"].to(device)
|
| 89 |
+
utt_embedding = batch["utt_embedding"].to(device)
|
| 90 |
+
spk_embedding = batch["spk_embedding"].to(device)
|
| 91 |
+
if args.mode == 'sft':
|
| 92 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 93 |
+
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
| 94 |
+
else:
|
| 95 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 96 |
+
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
| 97 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 98 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 99 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 100 |
+
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
| 101 |
+
tts_speeches = []
|
| 102 |
+
for model_output in model.tts(**model_input):
|
| 103 |
+
tts_speeches.append(model_output['tts_speech'])
|
| 104 |
+
tts_speeches = torch.concat(tts_speeches, dim=1)
|
| 105 |
+
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
| 106 |
+
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
| 107 |
+
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
|
| 108 |
+
f.write('{} {}\n'.format(tts_key, tts_fn))
|
| 109 |
+
f.flush()
|
| 110 |
+
f.close()
|
| 111 |
+
logging.info('Result wav.scp saved in {}'.format(fn))
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == '__main__':
|
| 115 |
+
main()
|
cosyvoice/bin/train.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import print_function
|
| 16 |
+
import argparse
|
| 17 |
+
import datetime
|
| 18 |
+
import logging
|
| 19 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
import os
|
| 22 |
+
import torch
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
import deepspeed
|
| 25 |
+
|
| 26 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 27 |
+
|
| 28 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
| 29 |
+
|
| 30 |
+
from cosyvoice.utils.executor import Executor
|
| 31 |
+
from cosyvoice.utils.train_utils import (
|
| 32 |
+
init_distributed,
|
| 33 |
+
init_dataset_and_dataloader,
|
| 34 |
+
init_optimizer_and_scheduler,
|
| 35 |
+
init_summarywriter, save_model,
|
| 36 |
+
wrap_cuda_model, check_modify_and_save_config)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_args():
|
| 40 |
+
parser = argparse.ArgumentParser(description='training your network')
|
| 41 |
+
parser.add_argument('--train_engine',
|
| 42 |
+
default='torch_ddp',
|
| 43 |
+
choices=['torch_ddp', 'deepspeed'],
|
| 44 |
+
help='Engine for paralleled training')
|
| 45 |
+
parser.add_argument('--model', required=True, help='model which will be trained')
|
| 46 |
+
parser.add_argument('--config', required=True, help='config file')
|
| 47 |
+
parser.add_argument('--train_data', required=True, help='train data file')
|
| 48 |
+
parser.add_argument('--cv_data', required=True, help='cv data file')
|
| 49 |
+
parser.add_argument('--checkpoint', help='checkpoint model')
|
| 50 |
+
parser.add_argument('--model_dir', required=True, help='save model dir')
|
| 51 |
+
parser.add_argument('--tensorboard_dir',
|
| 52 |
+
default='tensorboard',
|
| 53 |
+
help='tensorboard log dir')
|
| 54 |
+
parser.add_argument('--ddp.dist_backend',
|
| 55 |
+
dest='dist_backend',
|
| 56 |
+
default='nccl',
|
| 57 |
+
choices=['nccl', 'gloo'],
|
| 58 |
+
help='distributed backend')
|
| 59 |
+
parser.add_argument('--num_workers',
|
| 60 |
+
default=0,
|
| 61 |
+
type=int,
|
| 62 |
+
help='num of subprocess workers for reading')
|
| 63 |
+
parser.add_argument('--prefetch',
|
| 64 |
+
default=100,
|
| 65 |
+
type=int,
|
| 66 |
+
help='prefetch number')
|
| 67 |
+
parser.add_argument('--pin_memory',
|
| 68 |
+
action='store_true',
|
| 69 |
+
default=False,
|
| 70 |
+
help='Use pinned memory buffers used for reading')
|
| 71 |
+
parser.add_argument('--use_amp',
|
| 72 |
+
action='store_true',
|
| 73 |
+
default=False,
|
| 74 |
+
help='Use automatic mixed precision training')
|
| 75 |
+
parser.add_argument('--deepspeed.save_states',
|
| 76 |
+
dest='save_states',
|
| 77 |
+
default='model_only',
|
| 78 |
+
choices=['model_only', 'model+optimizer'],
|
| 79 |
+
help='save model/optimizer states')
|
| 80 |
+
parser.add_argument('--timeout',
|
| 81 |
+
default=60,
|
| 82 |
+
type=int,
|
| 83 |
+
help='timeout (in seconds) of cosyvoice_join.')
|
| 84 |
+
parser = deepspeed.add_config_arguments(parser)
|
| 85 |
+
args = parser.parse_args()
|
| 86 |
+
return args
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@record
|
| 90 |
+
def main():
|
| 91 |
+
args = get_args()
|
| 92 |
+
logging.basicConfig(level=logging.DEBUG,
|
| 93 |
+
format='%(asctime)s %(levelname)s %(message)s')
|
| 94 |
+
# gan train has some special initialization logic
|
| 95 |
+
gan = True if args.model == 'hifigan' else False
|
| 96 |
+
|
| 97 |
+
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
| 98 |
+
if gan is True:
|
| 99 |
+
override_dict.pop('hift')
|
| 100 |
+
with open(args.config, 'r') as f:
|
| 101 |
+
configs = load_hyperpyyaml(f, overrides=override_dict)
|
| 102 |
+
if gan is True:
|
| 103 |
+
configs['train_conf'] = configs['train_conf_gan']
|
| 104 |
+
configs['train_conf'].update(vars(args))
|
| 105 |
+
|
| 106 |
+
# Init env for ddp
|
| 107 |
+
init_distributed(args)
|
| 108 |
+
|
| 109 |
+
# Get dataset & dataloader
|
| 110 |
+
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
| 111 |
+
init_dataset_and_dataloader(args, configs, gan)
|
| 112 |
+
|
| 113 |
+
# Do some sanity checks and save config to arsg.model_dir
|
| 114 |
+
configs = check_modify_and_save_config(args, configs)
|
| 115 |
+
|
| 116 |
+
# Tensorboard summary
|
| 117 |
+
writer = init_summarywriter(args)
|
| 118 |
+
|
| 119 |
+
# load checkpoint
|
| 120 |
+
model = configs[args.model]
|
| 121 |
+
start_step, start_epoch = 0, -1
|
| 122 |
+
if args.checkpoint is not None:
|
| 123 |
+
if os.path.exists(args.checkpoint):
|
| 124 |
+
state_dict = torch.load(args.checkpoint, map_location='cpu')
|
| 125 |
+
model.load_state_dict(state_dict, strict=False)
|
| 126 |
+
if 'step' in state_dict:
|
| 127 |
+
start_step = state_dict['step']
|
| 128 |
+
if 'epoch' in state_dict:
|
| 129 |
+
start_epoch = state_dict['epoch']
|
| 130 |
+
else:
|
| 131 |
+
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
| 132 |
+
|
| 133 |
+
# Dispatch model from cpu to gpu
|
| 134 |
+
model = wrap_cuda_model(args, model)
|
| 135 |
+
|
| 136 |
+
# Get optimizer & scheduler
|
| 137 |
+
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
| 138 |
+
scheduler.set_step(start_step)
|
| 139 |
+
if scheduler_d is not None:
|
| 140 |
+
scheduler_d.set_step(start_step)
|
| 141 |
+
|
| 142 |
+
# Save init checkpoints
|
| 143 |
+
info_dict = deepcopy(configs['train_conf'])
|
| 144 |
+
info_dict['step'] = start_step
|
| 145 |
+
info_dict['epoch'] = start_epoch
|
| 146 |
+
save_model(model, 'init', info_dict)
|
| 147 |
+
|
| 148 |
+
# Get executor
|
| 149 |
+
executor = Executor(gan=gan)
|
| 150 |
+
executor.step = start_step
|
| 151 |
+
|
| 152 |
+
# Init scaler, used for pytorch amp mixed precision training
|
| 153 |
+
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
| 154 |
+
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
| 155 |
+
# Start training loop
|
| 156 |
+
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
| 157 |
+
executor.epoch = epoch
|
| 158 |
+
train_dataset.set_epoch(epoch)
|
| 159 |
+
dist.barrier()
|
| 160 |
+
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
| 161 |
+
if gan is True:
|
| 162 |
+
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
| 163 |
+
writer, info_dict, scaler, group_join)
|
| 164 |
+
else:
|
| 165 |
+
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
|
| 166 |
+
dist.destroy_process_group(group_join)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == '__main__':
|
| 170 |
+
main()
|
cosyvoice/cli/__init__.py
ADDED
|
File without changes
|
cosyvoice/cli/cosyvoice.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from hyperpyyaml import load_hyperpyyaml
|
| 18 |
+
from modelscope import snapshot_download
|
| 19 |
+
import torch
|
| 20 |
+
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
| 21 |
+
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
| 22 |
+
from cosyvoice.utils.file_utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CosyVoice:
|
| 26 |
+
|
| 27 |
+
def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
|
| 28 |
+
instruct = True if '-Instruct' in model_dir else False
|
| 29 |
+
self.model_dir = model_dir
|
| 30 |
+
if not os.path.exists(model_dir):
|
| 31 |
+
model_dir = snapshot_download(model_dir)
|
| 32 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
| 33 |
+
configs = load_hyperpyyaml(f)
|
| 34 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 35 |
+
configs['feat_extractor'],
|
| 36 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 37 |
+
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
| 38 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 39 |
+
instruct,
|
| 40 |
+
configs['allowed_special'])
|
| 41 |
+
self.sample_rate = configs['sample_rate']
|
| 42 |
+
if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
|
| 43 |
+
load_jit = False
|
| 44 |
+
fp16 = False
|
| 45 |
+
logging.warning('cpu do not support fp16 and jit, force set to False')
|
| 46 |
+
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
| 47 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 48 |
+
'{}/flow.pt'.format(model_dir),
|
| 49 |
+
'{}/hift.pt'.format(model_dir))
|
| 50 |
+
if load_jit:
|
| 51 |
+
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
|
| 52 |
+
'{}/llm.llm.fp16.zip'.format(model_dir),
|
| 53 |
+
'{}/flow.encoder.fp32.zip'.format(model_dir))
|
| 54 |
+
if load_onnx:
|
| 55 |
+
self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
|
| 56 |
+
del configs
|
| 57 |
+
|
| 58 |
+
def list_avaliable_spks(self):
|
| 59 |
+
spks = list(self.frontend.spk2info.keys())
|
| 60 |
+
return spks
|
| 61 |
+
|
| 62 |
+
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
|
| 63 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
| 64 |
+
model_input = self.frontend.frontend_sft(i, spk_id)
|
| 65 |
+
start_time = time.time()
|
| 66 |
+
logging.info('synthesis text {}'.format(i))
|
| 67 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 68 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 69 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 70 |
+
yield model_output
|
| 71 |
+
start_time = time.time()
|
| 72 |
+
|
| 73 |
+
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
|
| 74 |
+
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
| 75 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
| 76 |
+
if len(i) < 0.5 * len(prompt_text):
|
| 77 |
+
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
| 78 |
+
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
|
| 79 |
+
start_time = time.time()
|
| 80 |
+
logging.info('synthesis text {}'.format(i))
|
| 81 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 82 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 83 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 84 |
+
yield model_output
|
| 85 |
+
start_time = time.time()
|
| 86 |
+
|
| 87 |
+
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
|
| 88 |
+
if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
|
| 89 |
+
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
| 90 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
| 91 |
+
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
|
| 92 |
+
start_time = time.time()
|
| 93 |
+
logging.info('synthesis text {}'.format(i))
|
| 94 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 95 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 96 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 97 |
+
yield model_output
|
| 98 |
+
start_time = time.time()
|
| 99 |
+
|
| 100 |
+
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
|
| 101 |
+
assert isinstance(self.model, CosyVoiceModel)
|
| 102 |
+
if self.frontend.instruct is False:
|
| 103 |
+
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
| 104 |
+
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
| 105 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
| 106 |
+
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
| 107 |
+
start_time = time.time()
|
| 108 |
+
logging.info('synthesis text {}'.format(i))
|
| 109 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 110 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 111 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 112 |
+
yield model_output
|
| 113 |
+
start_time = time.time()
|
| 114 |
+
|
| 115 |
+
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0):
|
| 116 |
+
assert isinstance(self.model, CosyVoice2Model)
|
| 117 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
| 118 |
+
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
|
| 119 |
+
start_time = time.time()
|
| 120 |
+
logging.info('synthesis text {}'.format(i))
|
| 121 |
+
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
| 122 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 123 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 124 |
+
yield model_output
|
| 125 |
+
start_time = time.time()
|
| 126 |
+
|
| 127 |
+
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
|
| 128 |
+
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
|
| 129 |
+
start_time = time.time()
|
| 130 |
+
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
|
| 131 |
+
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
| 132 |
+
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
| 133 |
+
yield model_output
|
| 134 |
+
start_time = time.time()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class CosyVoice2(CosyVoice):
|
| 138 |
+
|
| 139 |
+
def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
|
| 140 |
+
instruct = True if '-Instruct' in model_dir else False
|
| 141 |
+
self.model_dir = model_dir
|
| 142 |
+
if not os.path.exists(model_dir):
|
| 143 |
+
model_dir = snapshot_download(model_dir)
|
| 144 |
+
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
| 145 |
+
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
| 146 |
+
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
| 147 |
+
configs['feat_extractor'],
|
| 148 |
+
'{}/campplus.onnx'.format(model_dir),
|
| 149 |
+
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
| 150 |
+
'{}/spk2info.pt'.format(model_dir),
|
| 151 |
+
instruct,
|
| 152 |
+
configs['allowed_special'])
|
| 153 |
+
self.sample_rate = configs['sample_rate']
|
| 154 |
+
if torch.cuda.is_available() is False and load_jit is True:
|
| 155 |
+
load_jit = False
|
| 156 |
+
logging.warning('cpu do not support jit, force set to False')
|
| 157 |
+
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
|
| 158 |
+
self.model.load('{}/llm.pt'.format(model_dir),
|
| 159 |
+
'{}/flow.pt'.format(model_dir),
|
| 160 |
+
'{}/hift.pt'.format(model_dir))
|
| 161 |
+
if load_jit:
|
| 162 |
+
self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
|
| 163 |
+
if load_trt is True and load_onnx is True:
|
| 164 |
+
load_onnx = False
|
| 165 |
+
logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
|
| 166 |
+
if load_onnx:
|
| 167 |
+
self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
|
| 168 |
+
if load_trt:
|
| 169 |
+
self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
|
| 170 |
+
del configs
|
cosyvoice/cli/frontend.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from functools import partial
|
| 15 |
+
import json
|
| 16 |
+
import onnxruntime
|
| 17 |
+
import torch
|
| 18 |
+
import numpy as np
|
| 19 |
+
import whisper
|
| 20 |
+
from typing import Callable
|
| 21 |
+
import torchaudio.compliance.kaldi as kaldi
|
| 22 |
+
import torchaudio
|
| 23 |
+
import os
|
| 24 |
+
import re
|
| 25 |
+
import inflect
|
| 26 |
+
try:
|
| 27 |
+
import ttsfrd
|
| 28 |
+
use_ttsfrd = True
|
| 29 |
+
except ImportError:
|
| 30 |
+
print("failed to import ttsfrd, use WeTextProcessing instead")
|
| 31 |
+
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
| 32 |
+
from tn.english.normalizer import Normalizer as EnNormalizer
|
| 33 |
+
use_ttsfrd = False
|
| 34 |
+
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CosyVoiceFrontEnd:
|
| 38 |
+
|
| 39 |
+
def __init__(self,
|
| 40 |
+
get_tokenizer: Callable,
|
| 41 |
+
feat_extractor: Callable,
|
| 42 |
+
campplus_model: str,
|
| 43 |
+
speech_tokenizer_model: str,
|
| 44 |
+
spk2info: str = '',
|
| 45 |
+
instruct: bool = False,
|
| 46 |
+
allowed_special: str = 'all'):
|
| 47 |
+
self.tokenizer = get_tokenizer()
|
| 48 |
+
self.feat_extractor = feat_extractor
|
| 49 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 50 |
+
option = onnxruntime.SessionOptions()
|
| 51 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 52 |
+
option.intra_op_num_threads = 1
|
| 53 |
+
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
| 54 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
| 55 |
+
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
| 56 |
+
"CPUExecutionProvider"])
|
| 57 |
+
if os.path.exists(spk2info):
|
| 58 |
+
self.spk2info = torch.load(spk2info, map_location=self.device)
|
| 59 |
+
else:
|
| 60 |
+
self.spk2info = {}
|
| 61 |
+
self.instruct = instruct
|
| 62 |
+
self.allowed_special = allowed_special
|
| 63 |
+
self.inflect_parser = inflect.engine()
|
| 64 |
+
self.use_ttsfrd = use_ttsfrd
|
| 65 |
+
if self.use_ttsfrd:
|
| 66 |
+
self.frd = ttsfrd.TtsFrontendEngine()
|
| 67 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 68 |
+
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
| 69 |
+
'failed to initialize ttsfrd resource'
|
| 70 |
+
self.frd.set_lang_type('pinyinvg')
|
| 71 |
+
else:
|
| 72 |
+
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
|
| 73 |
+
self.en_tn_model = EnNormalizer()
|
| 74 |
+
|
| 75 |
+
def _extract_text_token(self, text):
|
| 76 |
+
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
| 77 |
+
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
| 78 |
+
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 79 |
+
return text_token, text_token_len
|
| 80 |
+
|
| 81 |
+
def _extract_speech_token(self, speech):
|
| 82 |
+
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
| 83 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
| 84 |
+
speech_token = self.speech_tokenizer_session.run(None,
|
| 85 |
+
{self.speech_tokenizer_session.get_inputs()[0].name:
|
| 86 |
+
feat.detach().cpu().numpy(),
|
| 87 |
+
self.speech_tokenizer_session.get_inputs()[1].name:
|
| 88 |
+
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
| 89 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
| 90 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
| 91 |
+
return speech_token, speech_token_len
|
| 92 |
+
|
| 93 |
+
def _extract_spk_embedding(self, speech):
|
| 94 |
+
feat = kaldi.fbank(speech,
|
| 95 |
+
num_mel_bins=80,
|
| 96 |
+
dither=0,
|
| 97 |
+
sample_frequency=16000)
|
| 98 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 99 |
+
embedding = self.campplus_session.run(None,
|
| 100 |
+
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
| 101 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
| 102 |
+
return embedding
|
| 103 |
+
|
| 104 |
+
def _extract_speech_feat(self, speech):
|
| 105 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
| 106 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
| 107 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
| 108 |
+
return speech_feat, speech_feat_len
|
| 109 |
+
|
| 110 |
+
def text_normalize(self, text, split=True):
|
| 111 |
+
text = text.strip()
|
| 112 |
+
# NOTE(lyuxiang.lx) move this judgement into ttsfrd in the future
|
| 113 |
+
for token in self.tokenizer.special_tokens['additional_special_tokens']:
|
| 114 |
+
if token in text:
|
| 115 |
+
return text if split is False else [text]
|
| 116 |
+
if contains_chinese(text):
|
| 117 |
+
if self.use_ttsfrd:
|
| 118 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
| 119 |
+
text = ''.join(texts)
|
| 120 |
+
else:
|
| 121 |
+
text = self.zh_tn_model.normalize(text)
|
| 122 |
+
text = text.replace("\n", "")
|
| 123 |
+
text = replace_blank(text)
|
| 124 |
+
text = replace_corner_mark(text)
|
| 125 |
+
text = text.replace(".", "。")
|
| 126 |
+
text = text.replace(" - ", ",")
|
| 127 |
+
text = remove_bracket(text)
|
| 128 |
+
text = re.sub(r'[,,、]+$', '。', text)
|
| 129 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
| 130 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 131 |
+
else:
|
| 132 |
+
if self.use_ttsfrd:
|
| 133 |
+
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
| 134 |
+
text = ''.join(texts)
|
| 135 |
+
else:
|
| 136 |
+
text = self.en_tn_model.normalize(text)
|
| 137 |
+
text = spell_out_number(text, self.inflect_parser)
|
| 138 |
+
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
| 139 |
+
token_min_n=60, merge_len=20, comma_split=False))
|
| 140 |
+
if split is False:
|
| 141 |
+
return text
|
| 142 |
+
return texts
|
| 143 |
+
|
| 144 |
+
def frontend_sft(self, tts_text, spk_id):
|
| 145 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 146 |
+
embedding = self.spk2info[spk_id]['embedding']
|
| 147 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 148 |
+
return model_input
|
| 149 |
+
|
| 150 |
+
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
|
| 151 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 152 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
| 153 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 154 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 155 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 156 |
+
if resample_rate == 24000:
|
| 157 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
| 158 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
| 159 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
| 160 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
| 161 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 162 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 163 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
| 164 |
+
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
| 165 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 166 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 167 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 168 |
+
return model_input
|
| 169 |
+
|
| 170 |
+
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
|
| 171 |
+
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
|
| 172 |
+
# in cross lingual mode, we remove prompt in llm
|
| 173 |
+
del model_input['prompt_text']
|
| 174 |
+
del model_input['prompt_text_len']
|
| 175 |
+
del model_input['llm_prompt_speech_token']
|
| 176 |
+
del model_input['llm_prompt_speech_token_len']
|
| 177 |
+
return model_input
|
| 178 |
+
|
| 179 |
+
def frontend_instruct(self, tts_text, spk_id, instruct_text):
|
| 180 |
+
model_input = self.frontend_sft(tts_text, spk_id)
|
| 181 |
+
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
| 182 |
+
del model_input['llm_embedding']
|
| 183 |
+
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
|
| 184 |
+
model_input['prompt_text'] = instruct_text_token
|
| 185 |
+
model_input['prompt_text_len'] = instruct_text_token_len
|
| 186 |
+
return model_input
|
| 187 |
+
|
| 188 |
+
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
|
| 189 |
+
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
| 190 |
+
prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
|
| 191 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 192 |
+
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 193 |
+
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 194 |
+
if resample_rate == 24000:
|
| 195 |
+
# cosyvoice2, force speech_feat % speech_token = 2
|
| 196 |
+
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
| 197 |
+
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
| 198 |
+
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
| 199 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 200 |
+
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
| 201 |
+
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
| 202 |
+
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
| 203 |
+
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
| 204 |
+
'llm_embedding': embedding, 'flow_embedding': embedding}
|
| 205 |
+
return model_input
|
| 206 |
+
|
| 207 |
+
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
|
| 208 |
+
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
| 209 |
+
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
|
| 210 |
+
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
|
| 211 |
+
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
| 212 |
+
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
| 213 |
+
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
| 214 |
+
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
| 215 |
+
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
| 216 |
+
'flow_embedding': embedding}
|
| 217 |
+
return model_input
|
cosyvoice/cli/model.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import numpy as np
|
| 16 |
+
import threading
|
| 17 |
+
import time
|
| 18 |
+
from torch.nn import functional as F
|
| 19 |
+
from contextlib import nullcontext
|
| 20 |
+
import uuid
|
| 21 |
+
from cosyvoice.utils.common import fade_in_out
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CosyVoiceModel:
|
| 25 |
+
|
| 26 |
+
def __init__(self,
|
| 27 |
+
llm: torch.nn.Module,
|
| 28 |
+
flow: torch.nn.Module,
|
| 29 |
+
hift: torch.nn.Module,
|
| 30 |
+
fp16: bool):
|
| 31 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 32 |
+
self.llm = llm
|
| 33 |
+
self.flow = flow
|
| 34 |
+
self.hift = hift
|
| 35 |
+
self.fp16 = fp16
|
| 36 |
+
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
| 37 |
+
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
| 38 |
+
self.token_overlap_len = 20
|
| 39 |
+
# mel fade in out
|
| 40 |
+
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
| 41 |
+
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
| 42 |
+
# hift cache
|
| 43 |
+
self.mel_cache_len = 20
|
| 44 |
+
self.source_cache_len = int(self.mel_cache_len * 256)
|
| 45 |
+
# speech fade in out
|
| 46 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
| 47 |
+
# rtf and decoding related
|
| 48 |
+
self.stream_scale_factor = 1
|
| 49 |
+
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
| 50 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
| 51 |
+
self.lock = threading.Lock()
|
| 52 |
+
# dict used to store session related variable
|
| 53 |
+
self.tts_speech_token_dict = {}
|
| 54 |
+
self.llm_end_dict = {}
|
| 55 |
+
self.mel_overlap_dict = {}
|
| 56 |
+
self.flow_cache_dict = {}
|
| 57 |
+
self.hift_cache_dict = {}
|
| 58 |
+
|
| 59 |
+
def load(self, llm_model, flow_model, hift_model):
|
| 60 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
| 61 |
+
self.llm.to(self.device).eval()
|
| 62 |
+
if self.fp16 is True:
|
| 63 |
+
self.llm.half()
|
| 64 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
| 65 |
+
self.flow.to(self.device).eval()
|
| 66 |
+
# in case hift_model is a hifigan model
|
| 67 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
| 68 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
| 69 |
+
self.hift.to(self.device).eval()
|
| 70 |
+
|
| 71 |
+
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
| 72 |
+
assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
|
| 73 |
+
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
|
| 74 |
+
self.llm.text_encoder = llm_text_encoder
|
| 75 |
+
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
|
| 76 |
+
self.llm.llm = llm_llm
|
| 77 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
| 78 |
+
self.flow.encoder = flow_encoder
|
| 79 |
+
|
| 80 |
+
def load_onnx(self, flow_decoder_estimator_model):
|
| 81 |
+
import onnxruntime
|
| 82 |
+
option = onnxruntime.SessionOptions()
|
| 83 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 84 |
+
option.intra_op_num_threads = 1
|
| 85 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
| 86 |
+
del self.flow.decoder.estimator
|
| 87 |
+
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
|
| 88 |
+
|
| 89 |
+
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
| 90 |
+
if self.fp16 is True:
|
| 91 |
+
llm_embedding = llm_embedding.half()
|
| 92 |
+
with self.llm_context:
|
| 93 |
+
for i in self.llm.inference(text=text.to(self.device),
|
| 94 |
+
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
| 95 |
+
prompt_text=prompt_text.to(self.device),
|
| 96 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
| 97 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
| 98 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 99 |
+
embedding=llm_embedding.to(self.device)):
|
| 100 |
+
self.tts_speech_token_dict[uuid].append(i)
|
| 101 |
+
self.llm_end_dict[uuid] = True
|
| 102 |
+
|
| 103 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
| 104 |
+
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
|
| 105 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
| 106 |
+
prompt_token=prompt_token.to(self.device),
|
| 107 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 108 |
+
prompt_feat=prompt_feat.to(self.device),
|
| 109 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
| 110 |
+
embedding=embedding.to(self.device),
|
| 111 |
+
flow_cache=self.flow_cache_dict[uuid])
|
| 112 |
+
self.flow_cache_dict[uuid] = flow_cache
|
| 113 |
+
|
| 114 |
+
# mel overlap fade in out
|
| 115 |
+
if self.mel_overlap_dict[uuid].shape[2] != 0:
|
| 116 |
+
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
| 117 |
+
# append hift cache
|
| 118 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 119 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
| 120 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
| 121 |
+
else:
|
| 122 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
| 123 |
+
# keep overlap mel and hift cache
|
| 124 |
+
if finalize is False:
|
| 125 |
+
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
| 126 |
+
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
| 127 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 128 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 129 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 130 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
| 131 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
| 132 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
| 133 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
| 134 |
+
else:
|
| 135 |
+
if speed != 1.0:
|
| 136 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
| 137 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
| 138 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 139 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 140 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 141 |
+
return tts_speech
|
| 142 |
+
|
| 143 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
| 144 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
| 145 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 146 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 147 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
| 148 |
+
# this_uuid is used to track variables related to this inference thread
|
| 149 |
+
this_uuid = str(uuid.uuid1())
|
| 150 |
+
with self.lock:
|
| 151 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
| 152 |
+
self.hift_cache_dict[this_uuid] = None
|
| 153 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
| 154 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
| 155 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
| 156 |
+
p.start()
|
| 157 |
+
if stream is True:
|
| 158 |
+
token_hop_len = self.token_min_hop_len
|
| 159 |
+
while True:
|
| 160 |
+
time.sleep(0.1)
|
| 161 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
| 162 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
| 163 |
+
.unsqueeze(dim=0)
|
| 164 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 165 |
+
prompt_token=flow_prompt_speech_token,
|
| 166 |
+
prompt_feat=prompt_speech_feat,
|
| 167 |
+
embedding=flow_embedding,
|
| 168 |
+
uuid=this_uuid,
|
| 169 |
+
finalize=False)
|
| 170 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 171 |
+
with self.lock:
|
| 172 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
| 173 |
+
# increase token_hop_len for better speech quality
|
| 174 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
| 175 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
| 176 |
+
break
|
| 177 |
+
p.join()
|
| 178 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 179 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 180 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 181 |
+
prompt_token=flow_prompt_speech_token,
|
| 182 |
+
prompt_feat=prompt_speech_feat,
|
| 183 |
+
embedding=flow_embedding,
|
| 184 |
+
uuid=this_uuid,
|
| 185 |
+
finalize=True)
|
| 186 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 187 |
+
else:
|
| 188 |
+
# deal with all tokens
|
| 189 |
+
p.join()
|
| 190 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 191 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 192 |
+
prompt_token=flow_prompt_speech_token,
|
| 193 |
+
prompt_feat=prompt_speech_feat,
|
| 194 |
+
embedding=flow_embedding,
|
| 195 |
+
uuid=this_uuid,
|
| 196 |
+
finalize=True,
|
| 197 |
+
speed=speed)
|
| 198 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 199 |
+
with self.lock:
|
| 200 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 201 |
+
self.llm_end_dict.pop(this_uuid)
|
| 202 |
+
self.mel_overlap_dict.pop(this_uuid)
|
| 203 |
+
self.hift_cache_dict.pop(this_uuid)
|
| 204 |
+
self.flow_cache_dict.pop(this_uuid)
|
| 205 |
+
|
| 206 |
+
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
| 207 |
+
# this_uuid is used to track variables related to this inference thread
|
| 208 |
+
this_uuid = str(uuid.uuid1())
|
| 209 |
+
with self.lock:
|
| 210 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
|
| 211 |
+
self.hift_cache_dict[this_uuid] = None
|
| 212 |
+
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
| 213 |
+
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
| 214 |
+
if stream is True:
|
| 215 |
+
token_hop_len = self.token_min_hop_len
|
| 216 |
+
while True:
|
| 217 |
+
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
| 218 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
| 219 |
+
.unsqueeze(dim=0)
|
| 220 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 221 |
+
prompt_token=flow_prompt_speech_token,
|
| 222 |
+
prompt_feat=prompt_speech_feat,
|
| 223 |
+
embedding=flow_embedding,
|
| 224 |
+
uuid=this_uuid,
|
| 225 |
+
finalize=False)
|
| 226 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 227 |
+
with self.lock:
|
| 228 |
+
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
| 229 |
+
# increase token_hop_len for better speech quality
|
| 230 |
+
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
| 231 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
| 232 |
+
break
|
| 233 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 234 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 235 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 236 |
+
prompt_token=flow_prompt_speech_token,
|
| 237 |
+
prompt_feat=prompt_speech_feat,
|
| 238 |
+
embedding=flow_embedding,
|
| 239 |
+
uuid=this_uuid,
|
| 240 |
+
finalize=True)
|
| 241 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 242 |
+
else:
|
| 243 |
+
# deal with all tokens
|
| 244 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 245 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 246 |
+
prompt_token=flow_prompt_speech_token,
|
| 247 |
+
prompt_feat=prompt_speech_feat,
|
| 248 |
+
embedding=flow_embedding,
|
| 249 |
+
uuid=this_uuid,
|
| 250 |
+
finalize=True,
|
| 251 |
+
speed=speed)
|
| 252 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 253 |
+
with self.lock:
|
| 254 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 255 |
+
self.llm_end_dict.pop(this_uuid)
|
| 256 |
+
self.mel_overlap_dict.pop(this_uuid)
|
| 257 |
+
self.hift_cache_dict.pop(this_uuid)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class CosyVoice2Model:
|
| 261 |
+
|
| 262 |
+
def __init__(self,
|
| 263 |
+
llm: torch.nn.Module,
|
| 264 |
+
flow: torch.nn.Module,
|
| 265 |
+
hift: torch.nn.Module):
|
| 266 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 267 |
+
self.llm = llm
|
| 268 |
+
self.flow = flow
|
| 269 |
+
self.hift = hift
|
| 270 |
+
self.token_hop_len = 2 * self.flow.input_frame_rate
|
| 271 |
+
# here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
|
| 272 |
+
self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
|
| 273 |
+
self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
|
| 274 |
+
# hift cache
|
| 275 |
+
self.mel_cache_len = 8
|
| 276 |
+
self.source_cache_len = int(self.mel_cache_len * 480)
|
| 277 |
+
# speech fade in out
|
| 278 |
+
self.speech_window = np.hamming(2 * self.source_cache_len)
|
| 279 |
+
# rtf and decoding related
|
| 280 |
+
self.stream_scale_factor = 1
|
| 281 |
+
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
| 282 |
+
self.lock = threading.Lock()
|
| 283 |
+
# dict used to store session related variable
|
| 284 |
+
self.tts_speech_token_dict = {}
|
| 285 |
+
self.llm_end_dict = {}
|
| 286 |
+
self.hift_cache_dict = {}
|
| 287 |
+
|
| 288 |
+
def load(self, llm_model, flow_model, hift_model):
|
| 289 |
+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
| 290 |
+
self.llm.to(self.device).eval()
|
| 291 |
+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
| 292 |
+
self.flow.to(self.device).eval()
|
| 293 |
+
self.flow.decoder.fp16 = False
|
| 294 |
+
# in case hift_model is a hifigan model
|
| 295 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
| 296 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
| 297 |
+
self.hift.to(self.device).eval()
|
| 298 |
+
|
| 299 |
+
def load_jit(self, flow_encoder_model):
|
| 300 |
+
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
| 301 |
+
self.flow.encoder = flow_encoder
|
| 302 |
+
|
| 303 |
+
def load_onnx(self, flow_decoder_estimator_model):
|
| 304 |
+
import onnxruntime
|
| 305 |
+
option = onnxruntime.SessionOptions()
|
| 306 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 307 |
+
option.intra_op_num_threads = 1
|
| 308 |
+
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
| 309 |
+
del self.flow.decoder.estimator
|
| 310 |
+
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
|
| 311 |
+
|
| 312 |
+
def load_trt(self, flow_decoder_estimator_model):
|
| 313 |
+
del self.flow.decoder.estimator
|
| 314 |
+
import tensorrt as trt
|
| 315 |
+
with open(flow_decoder_estimator_model, 'rb') as f:
|
| 316 |
+
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
| 317 |
+
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
| 318 |
+
self.flow.decoder.fp16 = True
|
| 319 |
+
|
| 320 |
+
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
| 321 |
+
with self.llm_context:
|
| 322 |
+
for i in self.llm.inference(text=text.to(self.device),
|
| 323 |
+
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
| 324 |
+
prompt_text=prompt_text.to(self.device),
|
| 325 |
+
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
| 326 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
| 327 |
+
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 328 |
+
embedding=llm_embedding.to(self.device)):
|
| 329 |
+
self.tts_speech_token_dict[uuid].append(i)
|
| 330 |
+
self.llm_end_dict[uuid] = True
|
| 331 |
+
|
| 332 |
+
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
|
| 333 |
+
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
| 334 |
+
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
| 335 |
+
prompt_token=prompt_token.to(self.device),
|
| 336 |
+
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
| 337 |
+
prompt_feat=prompt_feat.to(self.device),
|
| 338 |
+
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
| 339 |
+
embedding=embedding.to(self.device),
|
| 340 |
+
finalize=finalize)
|
| 341 |
+
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
| 342 |
+
# append hift cache
|
| 343 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 344 |
+
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
| 345 |
+
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
| 346 |
+
else:
|
| 347 |
+
hift_cache_source = torch.zeros(1, 1, 0)
|
| 348 |
+
# keep overlap mel and hift cache
|
| 349 |
+
if finalize is False:
|
| 350 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 351 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 352 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 353 |
+
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
| 354 |
+
'source': tts_source[:, :, -self.source_cache_len:],
|
| 355 |
+
'speech': tts_speech[:, -self.source_cache_len:]}
|
| 356 |
+
tts_speech = tts_speech[:, :-self.source_cache_len]
|
| 357 |
+
else:
|
| 358 |
+
if speed != 1.0:
|
| 359 |
+
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
| 360 |
+
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
| 361 |
+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
| 362 |
+
if self.hift_cache_dict[uuid] is not None:
|
| 363 |
+
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
| 364 |
+
return tts_speech
|
| 365 |
+
|
| 366 |
+
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
| 367 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
| 368 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 369 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
| 370 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
| 371 |
+
# this_uuid is used to track variables related to this inference thread
|
| 372 |
+
this_uuid = str(uuid.uuid1())
|
| 373 |
+
with self.lock:
|
| 374 |
+
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
| 375 |
+
self.hift_cache_dict[this_uuid] = None
|
| 376 |
+
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
| 377 |
+
p.start()
|
| 378 |
+
if stream is True:
|
| 379 |
+
token_offset = 0
|
| 380 |
+
while True:
|
| 381 |
+
time.sleep(0.1)
|
| 382 |
+
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
|
| 383 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
| 384 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 385 |
+
prompt_token=flow_prompt_speech_token,
|
| 386 |
+
prompt_feat=prompt_speech_feat,
|
| 387 |
+
embedding=flow_embedding,
|
| 388 |
+
uuid=this_uuid,
|
| 389 |
+
token_offset=token_offset,
|
| 390 |
+
finalize=False)
|
| 391 |
+
token_offset += self.token_hop_len
|
| 392 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 393 |
+
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
|
| 394 |
+
break
|
| 395 |
+
p.join()
|
| 396 |
+
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
| 397 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 398 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 399 |
+
prompt_token=flow_prompt_speech_token,
|
| 400 |
+
prompt_feat=prompt_speech_feat,
|
| 401 |
+
embedding=flow_embedding,
|
| 402 |
+
uuid=this_uuid,
|
| 403 |
+
token_offset=token_offset,
|
| 404 |
+
finalize=True)
|
| 405 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 406 |
+
else:
|
| 407 |
+
# deal with all tokens
|
| 408 |
+
p.join()
|
| 409 |
+
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
| 410 |
+
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
| 411 |
+
prompt_token=flow_prompt_speech_token,
|
| 412 |
+
prompt_feat=prompt_speech_feat,
|
| 413 |
+
embedding=flow_embedding,
|
| 414 |
+
uuid=this_uuid,
|
| 415 |
+
token_offset=0,
|
| 416 |
+
finalize=True,
|
| 417 |
+
speed=speed)
|
| 418 |
+
yield {'tts_speech': this_tts_speech.cpu()}
|
| 419 |
+
with self.lock:
|
| 420 |
+
self.tts_speech_token_dict.pop(this_uuid)
|
| 421 |
+
self.llm_end_dict.pop(this_uuid)
|
cosyvoice/dataset/__init__.py
ADDED
|
File without changes
|
cosyvoice/dataset/dataset.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
| 2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import random
|
| 17 |
+
import json
|
| 18 |
+
import math
|
| 19 |
+
from functools import partial
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from torch.utils.data import IterableDataset
|
| 24 |
+
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Processor(IterableDataset):
|
| 28 |
+
|
| 29 |
+
def __init__(self, source, f, *args, **kw):
|
| 30 |
+
assert callable(f)
|
| 31 |
+
self.source = source
|
| 32 |
+
self.f = f
|
| 33 |
+
self.args = args
|
| 34 |
+
self.kw = kw
|
| 35 |
+
|
| 36 |
+
def set_epoch(self, epoch):
|
| 37 |
+
self.source.set_epoch(epoch)
|
| 38 |
+
|
| 39 |
+
def __iter__(self):
|
| 40 |
+
""" Return an iterator over the source dataset processed by the
|
| 41 |
+
given processor.
|
| 42 |
+
"""
|
| 43 |
+
assert self.source is not None
|
| 44 |
+
assert callable(self.f)
|
| 45 |
+
return self.f(iter(self.source), *self.args, **self.kw)
|
| 46 |
+
|
| 47 |
+
def apply(self, f):
|
| 48 |
+
assert callable(f)
|
| 49 |
+
return Processor(self, f, *self.args, **self.kw)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DistributedSampler:
|
| 53 |
+
|
| 54 |
+
def __init__(self, shuffle=True, partition=True):
|
| 55 |
+
self.epoch = -1
|
| 56 |
+
self.update()
|
| 57 |
+
self.shuffle = shuffle
|
| 58 |
+
self.partition = partition
|
| 59 |
+
|
| 60 |
+
def update(self):
|
| 61 |
+
assert dist.is_available()
|
| 62 |
+
if dist.is_initialized():
|
| 63 |
+
self.rank = dist.get_rank()
|
| 64 |
+
self.world_size = dist.get_world_size()
|
| 65 |
+
else:
|
| 66 |
+
self.rank = 0
|
| 67 |
+
self.world_size = 1
|
| 68 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 69 |
+
if worker_info is None:
|
| 70 |
+
self.worker_id = 0
|
| 71 |
+
self.num_workers = 1
|
| 72 |
+
else:
|
| 73 |
+
self.worker_id = worker_info.id
|
| 74 |
+
self.num_workers = worker_info.num_workers
|
| 75 |
+
return dict(rank=self.rank,
|
| 76 |
+
world_size=self.world_size,
|
| 77 |
+
worker_id=self.worker_id,
|
| 78 |
+
num_workers=self.num_workers)
|
| 79 |
+
|
| 80 |
+
def set_epoch(self, epoch):
|
| 81 |
+
self.epoch = epoch
|
| 82 |
+
|
| 83 |
+
def sample(self, data):
|
| 84 |
+
""" Sample data according to rank/world_size/num_workers
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
data(List): input data list
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
List: data list after sample
|
| 91 |
+
"""
|
| 92 |
+
data = list(range(len(data)))
|
| 93 |
+
# force datalist even
|
| 94 |
+
if self.partition:
|
| 95 |
+
if self.shuffle:
|
| 96 |
+
random.Random(self.epoch).shuffle(data)
|
| 97 |
+
if len(data) < self.world_size:
|
| 98 |
+
data = data * math.ceil(self.world_size / len(data))
|
| 99 |
+
data = data[:self.world_size]
|
| 100 |
+
data = data[self.rank::self.world_size]
|
| 101 |
+
if len(data) < self.num_workers:
|
| 102 |
+
data = data * math.ceil(self.num_workers / len(data))
|
| 103 |
+
data = data[:self.num_workers]
|
| 104 |
+
data = data[self.worker_id::self.num_workers]
|
| 105 |
+
return data
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class DataList(IterableDataset):
|
| 109 |
+
|
| 110 |
+
def __init__(self, lists, shuffle=True, partition=True):
|
| 111 |
+
self.lists = lists
|
| 112 |
+
self.sampler = DistributedSampler(shuffle, partition)
|
| 113 |
+
|
| 114 |
+
def set_epoch(self, epoch):
|
| 115 |
+
self.sampler.set_epoch(epoch)
|
| 116 |
+
|
| 117 |
+
def __iter__(self):
|
| 118 |
+
sampler_info = self.sampler.update()
|
| 119 |
+
indexes = self.sampler.sample(self.lists)
|
| 120 |
+
for index in indexes:
|
| 121 |
+
data = dict(src=self.lists[index])
|
| 122 |
+
data.update(sampler_info)
|
| 123 |
+
yield data
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def Dataset(data_list_file,
|
| 127 |
+
data_pipeline,
|
| 128 |
+
mode='train',
|
| 129 |
+
gan=False,
|
| 130 |
+
shuffle=True,
|
| 131 |
+
partition=True,
|
| 132 |
+
tts_file='',
|
| 133 |
+
prompt_utt2data=''):
|
| 134 |
+
""" Construct dataset from arguments
|
| 135 |
+
|
| 136 |
+
We have two shuffle stage in the Dataset. The first is global
|
| 137 |
+
shuffle at shards tar/raw file level. The second is global shuffle
|
| 138 |
+
at training samples level.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
data_type(str): raw/shard
|
| 142 |
+
tokenizer (BaseTokenizer): tokenizer to tokenize
|
| 143 |
+
partition(bool): whether to do data partition in terms of rank
|
| 144 |
+
"""
|
| 145 |
+
assert mode in ['train', 'inference']
|
| 146 |
+
lists = read_lists(data_list_file)
|
| 147 |
+
if mode == 'inference':
|
| 148 |
+
with open(tts_file) as f:
|
| 149 |
+
tts_data = json.load(f)
|
| 150 |
+
utt2lists = read_json_lists(prompt_utt2data)
|
| 151 |
+
# filter unnecessary file in inference mode
|
| 152 |
+
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
|
| 153 |
+
dataset = DataList(lists,
|
| 154 |
+
shuffle=shuffle,
|
| 155 |
+
partition=partition)
|
| 156 |
+
if mode == 'inference':
|
| 157 |
+
# map partial arg to parquet_opener func in inference mode
|
| 158 |
+
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
| 159 |
+
if gan is True:
|
| 160 |
+
# map partial arg to padding func in gan mode
|
| 161 |
+
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
| 162 |
+
for func in data_pipeline:
|
| 163 |
+
dataset = Processor(dataset, func, mode=mode)
|
| 164 |
+
return dataset
|
cosyvoice/dataset/processor.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
import pyarrow.parquet as pq
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
import torch
|
| 20 |
+
import torchaudio
|
| 21 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
torchaudio.set_audio_backend('soundfile')
|
| 25 |
+
|
| 26 |
+
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parquet_opener(data, mode='train', tts_data={}):
|
| 30 |
+
""" Give url or local file, return file descriptor
|
| 31 |
+
Inplace operation.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
data(Iterable[str]): url or local file list
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Iterable[{src, stream}]
|
| 38 |
+
"""
|
| 39 |
+
for sample in data:
|
| 40 |
+
assert 'src' in sample
|
| 41 |
+
url = sample['src']
|
| 42 |
+
try:
|
| 43 |
+
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
| 44 |
+
df = df.to_pandas()
|
| 45 |
+
for i in range(len(df)):
|
| 46 |
+
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
| 47 |
+
continue
|
| 48 |
+
sample.update(dict(df.loc[i]))
|
| 49 |
+
if mode == 'train':
|
| 50 |
+
# NOTE do not return sample directly, must initialize a new dict
|
| 51 |
+
yield {**sample}
|
| 52 |
+
else:
|
| 53 |
+
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
| 54 |
+
yield {**sample, 'tts_index': index, 'tts_text': text}
|
| 55 |
+
except Exception as ex:
|
| 56 |
+
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def filter(data,
|
| 60 |
+
max_length=10240,
|
| 61 |
+
min_length=10,
|
| 62 |
+
token_max_length=200,
|
| 63 |
+
token_min_length=1,
|
| 64 |
+
min_output_input_ratio=0.0005,
|
| 65 |
+
max_output_input_ratio=1,
|
| 66 |
+
mode='train'):
|
| 67 |
+
""" Filter sample according to feature and label length
|
| 68 |
+
Inplace operation.
|
| 69 |
+
|
| 70 |
+
Args::
|
| 71 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 72 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
| 73 |
+
min_length: drop utterance which is less than min_length(10ms)
|
| 74 |
+
token_max_length: drop utterance which is greater than
|
| 75 |
+
token_max_length, especially when use char unit for
|
| 76 |
+
english modeling
|
| 77 |
+
token_min_length: drop utterance which is
|
| 78 |
+
less than token_max_length
|
| 79 |
+
min_output_input_ratio: minimal ration of
|
| 80 |
+
token_length / feats_length(10ms)
|
| 81 |
+
max_output_input_ratio: maximum ration of
|
| 82 |
+
token_length / feats_length(10ms)
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Iterable[{key, wav, label, sample_rate}]
|
| 86 |
+
"""
|
| 87 |
+
for sample in data:
|
| 88 |
+
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
| 89 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
| 90 |
+
del sample['audio_data']
|
| 91 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
| 92 |
+
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
| 93 |
+
if num_frames < min_length:
|
| 94 |
+
continue
|
| 95 |
+
if num_frames > max_length:
|
| 96 |
+
continue
|
| 97 |
+
if len(sample['text_token']) < token_min_length:
|
| 98 |
+
continue
|
| 99 |
+
if len(sample['text_token']) > token_max_length:
|
| 100 |
+
continue
|
| 101 |
+
if len(sample['speech_token']) == 0:
|
| 102 |
+
continue
|
| 103 |
+
if num_frames != 0:
|
| 104 |
+
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
| 105 |
+
continue
|
| 106 |
+
if len(sample['text_token']) / num_frames > max_output_input_ratio:
|
| 107 |
+
continue
|
| 108 |
+
yield sample
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
| 112 |
+
""" Resample data.
|
| 113 |
+
Inplace operation.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 117 |
+
resample_rate: target resample rate
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Iterable[{key, wav, label, sample_rate}]
|
| 121 |
+
"""
|
| 122 |
+
for sample in data:
|
| 123 |
+
assert 'sample_rate' in sample
|
| 124 |
+
assert 'speech' in sample
|
| 125 |
+
sample_rate = sample['sample_rate']
|
| 126 |
+
waveform = sample['speech']
|
| 127 |
+
if sample_rate != resample_rate:
|
| 128 |
+
if sample_rate < min_sample_rate:
|
| 129 |
+
continue
|
| 130 |
+
sample['sample_rate'] = resample_rate
|
| 131 |
+
sample['speech'] = torchaudio.transforms.Resample(
|
| 132 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
| 133 |
+
max_val = sample['speech'].abs().max()
|
| 134 |
+
if max_val > 1:
|
| 135 |
+
sample['speech'] /= max_val
|
| 136 |
+
yield sample
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def truncate(data, truncate_length=24576, mode='train'):
|
| 140 |
+
""" Truncate data.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 144 |
+
truncate_length: truncate length
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Iterable[{key, wav, label, sample_rate}]
|
| 148 |
+
"""
|
| 149 |
+
for sample in data:
|
| 150 |
+
waveform = sample['speech']
|
| 151 |
+
if waveform.shape[1] > truncate_length:
|
| 152 |
+
start = random.randint(0, waveform.shape[1] - truncate_length)
|
| 153 |
+
waveform = waveform[:, start: start + truncate_length]
|
| 154 |
+
else:
|
| 155 |
+
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
| 156 |
+
sample['speech'] = waveform
|
| 157 |
+
yield sample
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def compute_fbank(data,
|
| 161 |
+
feat_extractor,
|
| 162 |
+
mode='train'):
|
| 163 |
+
""" Extract fbank
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
Iterable[{key, feat, label}]
|
| 170 |
+
"""
|
| 171 |
+
for sample in data:
|
| 172 |
+
assert 'sample_rate' in sample
|
| 173 |
+
assert 'speech' in sample
|
| 174 |
+
assert 'utt' in sample
|
| 175 |
+
assert 'text_token' in sample
|
| 176 |
+
waveform = sample['speech']
|
| 177 |
+
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
| 178 |
+
sample['speech_feat'] = mat
|
| 179 |
+
yield sample
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def compute_f0(data, pitch_extractor, mode='train'):
|
| 183 |
+
""" Extract f0
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Iterable[{key, feat, label}]
|
| 190 |
+
"""
|
| 191 |
+
for sample in data:
|
| 192 |
+
assert 'sample_rate' in sample
|
| 193 |
+
assert 'speech' in sample
|
| 194 |
+
assert 'utt' in sample
|
| 195 |
+
assert 'text_token' in sample
|
| 196 |
+
waveform = sample['speech']
|
| 197 |
+
mat = pitch_extractor(waveform).transpose(1, 2)
|
| 198 |
+
mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
|
| 199 |
+
sample['pitch_feat'] = mat[0, 0]
|
| 200 |
+
yield sample
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def parse_embedding(data, normalize, mode='train'):
|
| 204 |
+
""" Parse utt_embedding/spk_embedding
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
Iterable[{key, feat, label}]
|
| 211 |
+
"""
|
| 212 |
+
for sample in data:
|
| 213 |
+
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
| 214 |
+
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
| 215 |
+
if normalize:
|
| 216 |
+
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
| 217 |
+
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
| 218 |
+
yield sample
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
| 222 |
+
""" Decode text to chars or BPE
|
| 223 |
+
Inplace operation
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
| 230 |
+
"""
|
| 231 |
+
tokenizer = get_tokenizer()
|
| 232 |
+
for sample in data:
|
| 233 |
+
assert 'text' in sample
|
| 234 |
+
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
| 235 |
+
if mode == 'inference':
|
| 236 |
+
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
| 237 |
+
yield sample
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def shuffle(data, shuffle_size=10000, mode='train'):
|
| 241 |
+
""" Local shuffle the data
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
data: Iterable[{key, feat, label}]
|
| 245 |
+
shuffle_size: buffer size for shuffle
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
Iterable[{key, feat, label}]
|
| 249 |
+
"""
|
| 250 |
+
buf = []
|
| 251 |
+
for sample in data:
|
| 252 |
+
buf.append(sample)
|
| 253 |
+
if len(buf) >= shuffle_size:
|
| 254 |
+
random.shuffle(buf)
|
| 255 |
+
for x in buf:
|
| 256 |
+
yield x
|
| 257 |
+
buf = []
|
| 258 |
+
# The sample left over
|
| 259 |
+
random.shuffle(buf)
|
| 260 |
+
for x in buf:
|
| 261 |
+
yield x
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def sort(data, sort_size=500, mode='train'):
|
| 265 |
+
""" Sort the data by feature length.
|
| 266 |
+
Sort is used after shuffle and before batch, so we can group
|
| 267 |
+
utts with similar lengths into a batch, and `sort_size` should
|
| 268 |
+
be less than `shuffle_size`
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
data: Iterable[{key, feat, label}]
|
| 272 |
+
sort_size: buffer size for sort
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
Iterable[{key, feat, label}]
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
buf = []
|
| 279 |
+
for sample in data:
|
| 280 |
+
buf.append(sample)
|
| 281 |
+
if len(buf) >= sort_size:
|
| 282 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
| 283 |
+
for x in buf:
|
| 284 |
+
yield x
|
| 285 |
+
buf = []
|
| 286 |
+
# The sample left over
|
| 287 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
| 288 |
+
for x in buf:
|
| 289 |
+
yield x
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def static_batch(data, batch_size=16):
|
| 293 |
+
""" Static batch the data by `batch_size`
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
data: Iterable[{key, feat, label}]
|
| 297 |
+
batch_size: batch size
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
Iterable[List[{key, feat, label}]]
|
| 301 |
+
"""
|
| 302 |
+
buf = []
|
| 303 |
+
for sample in data:
|
| 304 |
+
buf.append(sample)
|
| 305 |
+
if len(buf) >= batch_size:
|
| 306 |
+
yield buf
|
| 307 |
+
buf = []
|
| 308 |
+
if len(buf) > 0:
|
| 309 |
+
yield buf
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
| 313 |
+
""" Dynamic batch the data until the total frames in batch
|
| 314 |
+
reach `max_frames_in_batch`
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
data: Iterable[{key, feat, label}]
|
| 318 |
+
max_frames_in_batch: max_frames in one batch
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
Iterable[List[{key, feat, label}]]
|
| 322 |
+
"""
|
| 323 |
+
buf = []
|
| 324 |
+
longest_frames = 0
|
| 325 |
+
for sample in data:
|
| 326 |
+
assert 'speech_feat' in sample
|
| 327 |
+
assert isinstance(sample['speech_feat'], torch.Tensor)
|
| 328 |
+
new_sample_frames = sample['speech_feat'].size(0)
|
| 329 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
| 330 |
+
frames_after_padding = longest_frames * (len(buf) + 1)
|
| 331 |
+
if frames_after_padding > max_frames_in_batch:
|
| 332 |
+
yield buf
|
| 333 |
+
buf = [sample]
|
| 334 |
+
longest_frames = new_sample_frames
|
| 335 |
+
else:
|
| 336 |
+
buf.append(sample)
|
| 337 |
+
if len(buf) > 0:
|
| 338 |
+
yield buf
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
| 342 |
+
""" Wrapper for static/dynamic batch
|
| 343 |
+
"""
|
| 344 |
+
if mode == 'inference':
|
| 345 |
+
return static_batch(data, 1)
|
| 346 |
+
else:
|
| 347 |
+
if batch_type == 'static':
|
| 348 |
+
return static_batch(data, batch_size)
|
| 349 |
+
elif batch_type == 'dynamic':
|
| 350 |
+
return dynamic_batch(data, max_frames_in_batch)
|
| 351 |
+
else:
|
| 352 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def padding(data, use_spk_embedding, mode='train', gan=False):
|
| 356 |
+
""" Padding the data into training data
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
data: Iterable[List[{key, feat, label}]]
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
| 363 |
+
"""
|
| 364 |
+
for sample in data:
|
| 365 |
+
assert isinstance(sample, list)
|
| 366 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
| 367 |
+
dtype=torch.int32)
|
| 368 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
| 369 |
+
|
| 370 |
+
utts = [sample[i]['utt'] for i in order]
|
| 371 |
+
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
| 372 |
+
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
| 373 |
+
speech = pad_sequence(speech, batch_first=True, padding_value=0)
|
| 374 |
+
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
| 375 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
| 376 |
+
speech_token = pad_sequence(speech_token,
|
| 377 |
+
batch_first=True,
|
| 378 |
+
padding_value=0)
|
| 379 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
| 380 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
| 381 |
+
speech_feat = pad_sequence(speech_feat,
|
| 382 |
+
batch_first=True,
|
| 383 |
+
padding_value=0)
|
| 384 |
+
text = [sample[i]['text'] for i in order]
|
| 385 |
+
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
| 386 |
+
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
| 387 |
+
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
| 388 |
+
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
| 389 |
+
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
| 390 |
+
batch = {
|
| 391 |
+
"utts": utts,
|
| 392 |
+
"speech": speech,
|
| 393 |
+
"speech_len": speech_len,
|
| 394 |
+
"speech_token": speech_token,
|
| 395 |
+
"speech_token_len": speech_token_len,
|
| 396 |
+
"speech_feat": speech_feat,
|
| 397 |
+
"speech_feat_len": speech_feat_len,
|
| 398 |
+
"text": text,
|
| 399 |
+
"text_token": text_token,
|
| 400 |
+
"text_token_len": text_token_len,
|
| 401 |
+
"utt_embedding": utt_embedding,
|
| 402 |
+
"spk_embedding": spk_embedding,
|
| 403 |
+
}
|
| 404 |
+
if gan is True:
|
| 405 |
+
# in gan train, we need pitch_feat
|
| 406 |
+
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
| 407 |
+
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
| 408 |
+
pitch_feat = pad_sequence(pitch_feat,
|
| 409 |
+
batch_first=True,
|
| 410 |
+
padding_value=0)
|
| 411 |
+
batch["pitch_feat"] = pitch_feat
|
| 412 |
+
batch["pitch_feat_len"] = pitch_feat_len
|
| 413 |
+
else:
|
| 414 |
+
# only gan train needs speech, delete it to save memory
|
| 415 |
+
del batch["speech"]
|
| 416 |
+
del batch["speech_len"]
|
| 417 |
+
if mode == 'inference':
|
| 418 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
| 419 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
| 420 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
| 421 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
| 422 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
| 423 |
+
batch.update({'tts_text': tts_text,
|
| 424 |
+
'tts_index': tts_index,
|
| 425 |
+
'tts_text_token': tts_text_token,
|
| 426 |
+
'tts_text_token_len': tts_text_token_len})
|
| 427 |
+
if use_spk_embedding is True:
|
| 428 |
+
batch["embedding"] = batch["spk_embedding"]
|
| 429 |
+
else:
|
| 430 |
+
batch["embedding"] = batch["utt_embedding"]
|
| 431 |
+
yield batch
|
cosyvoice/flow/decoder.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from einops import pack, rearrange, repeat
|
| 18 |
+
from cosyvoice.utils.common import mask_to_bias
|
| 19 |
+
from cosyvoice.utils.mask import add_optional_chunk_mask
|
| 20 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
| 21 |
+
from matcha.models.components.transformer import BasicTransformerBlock
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Transpose(torch.nn.Module):
|
| 25 |
+
def __init__(self, dim0: int, dim1: int):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.dim0 = dim0
|
| 28 |
+
self.dim1 = dim1
|
| 29 |
+
|
| 30 |
+
def forward(self, x: torch.Tensor):
|
| 31 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
| 32 |
+
return x
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CausalBlock1D(Block1D):
|
| 36 |
+
def __init__(self, dim: int, dim_out: int):
|
| 37 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
| 38 |
+
self.block = torch.nn.Sequential(
|
| 39 |
+
CausalConv1d(dim, dim_out, 3),
|
| 40 |
+
Transpose(1, 2),
|
| 41 |
+
nn.LayerNorm(dim_out),
|
| 42 |
+
Transpose(1, 2),
|
| 43 |
+
nn.Mish(),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
| 47 |
+
output = self.block(x * mask)
|
| 48 |
+
return output * mask
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
| 52 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
| 53 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
| 54 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
| 55 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CausalConv1d(torch.nn.Conv1d):
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
in_channels: int,
|
| 62 |
+
out_channels: int,
|
| 63 |
+
kernel_size: int,
|
| 64 |
+
stride: int = 1,
|
| 65 |
+
dilation: int = 1,
|
| 66 |
+
groups: int = 1,
|
| 67 |
+
bias: bool = True,
|
| 68 |
+
padding_mode: str = 'zeros',
|
| 69 |
+
device=None,
|
| 70 |
+
dtype=None
|
| 71 |
+
) -> None:
|
| 72 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
| 73 |
+
kernel_size, stride,
|
| 74 |
+
padding=0, dilation=dilation,
|
| 75 |
+
groups=groups, bias=bias,
|
| 76 |
+
padding_mode=padding_mode,
|
| 77 |
+
device=device, dtype=dtype)
|
| 78 |
+
assert stride == 1
|
| 79 |
+
self.causal_padding = (kernel_size - 1, 0)
|
| 80 |
+
|
| 81 |
+
def forward(self, x: torch.Tensor):
|
| 82 |
+
x = F.pad(x, self.causal_padding)
|
| 83 |
+
x = super(CausalConv1d, self).forward(x)
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ConditionalDecoder(nn.Module):
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
in_channels,
|
| 91 |
+
out_channels,
|
| 92 |
+
causal=False,
|
| 93 |
+
channels=(256, 256),
|
| 94 |
+
dropout=0.05,
|
| 95 |
+
attention_head_dim=64,
|
| 96 |
+
n_blocks=1,
|
| 97 |
+
num_mid_blocks=2,
|
| 98 |
+
num_heads=4,
|
| 99 |
+
act_fn="snake",
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
| 103 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
| 104 |
+
"""
|
| 105 |
+
super().__init__()
|
| 106 |
+
channels = tuple(channels)
|
| 107 |
+
self.in_channels = in_channels
|
| 108 |
+
self.out_channels = out_channels
|
| 109 |
+
self.causal = causal
|
| 110 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
| 111 |
+
time_embed_dim = channels[0] * 4
|
| 112 |
+
self.time_mlp = TimestepEmbedding(
|
| 113 |
+
in_channels=in_channels,
|
| 114 |
+
time_embed_dim=time_embed_dim,
|
| 115 |
+
act_fn="silu",
|
| 116 |
+
)
|
| 117 |
+
self.down_blocks = nn.ModuleList([])
|
| 118 |
+
self.mid_blocks = nn.ModuleList([])
|
| 119 |
+
self.up_blocks = nn.ModuleList([])
|
| 120 |
+
|
| 121 |
+
output_channel = in_channels
|
| 122 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
| 123 |
+
input_channel = output_channel
|
| 124 |
+
output_channel = channels[i]
|
| 125 |
+
is_last = i == len(channels) - 1
|
| 126 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 127 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 128 |
+
transformer_blocks = nn.ModuleList(
|
| 129 |
+
[
|
| 130 |
+
BasicTransformerBlock(
|
| 131 |
+
dim=output_channel,
|
| 132 |
+
num_attention_heads=num_heads,
|
| 133 |
+
attention_head_dim=attention_head_dim,
|
| 134 |
+
dropout=dropout,
|
| 135 |
+
activation_fn=act_fn,
|
| 136 |
+
)
|
| 137 |
+
for _ in range(n_blocks)
|
| 138 |
+
]
|
| 139 |
+
)
|
| 140 |
+
downsample = (
|
| 141 |
+
Downsample1D(output_channel) if not is_last else
|
| 142 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 143 |
+
)
|
| 144 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
| 145 |
+
|
| 146 |
+
for _ in range(num_mid_blocks):
|
| 147 |
+
input_channel = channels[-1]
|
| 148 |
+
out_channels = channels[-1]
|
| 149 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 150 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 151 |
+
|
| 152 |
+
transformer_blocks = nn.ModuleList(
|
| 153 |
+
[
|
| 154 |
+
BasicTransformerBlock(
|
| 155 |
+
dim=output_channel,
|
| 156 |
+
num_attention_heads=num_heads,
|
| 157 |
+
attention_head_dim=attention_head_dim,
|
| 158 |
+
dropout=dropout,
|
| 159 |
+
activation_fn=act_fn,
|
| 160 |
+
)
|
| 161 |
+
for _ in range(n_blocks)
|
| 162 |
+
]
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
| 166 |
+
|
| 167 |
+
channels = channels[::-1] + (channels[0],)
|
| 168 |
+
for i in range(len(channels) - 1):
|
| 169 |
+
input_channel = channels[i] * 2
|
| 170 |
+
output_channel = channels[i + 1]
|
| 171 |
+
is_last = i == len(channels) - 2
|
| 172 |
+
resnet = CausalResnetBlock1D(
|
| 173 |
+
dim=input_channel,
|
| 174 |
+
dim_out=output_channel,
|
| 175 |
+
time_emb_dim=time_embed_dim,
|
| 176 |
+
) if self.causal else ResnetBlock1D(
|
| 177 |
+
dim=input_channel,
|
| 178 |
+
dim_out=output_channel,
|
| 179 |
+
time_emb_dim=time_embed_dim,
|
| 180 |
+
)
|
| 181 |
+
transformer_blocks = nn.ModuleList(
|
| 182 |
+
[
|
| 183 |
+
BasicTransformerBlock(
|
| 184 |
+
dim=output_channel,
|
| 185 |
+
num_attention_heads=num_heads,
|
| 186 |
+
attention_head_dim=attention_head_dim,
|
| 187 |
+
dropout=dropout,
|
| 188 |
+
activation_fn=act_fn,
|
| 189 |
+
)
|
| 190 |
+
for _ in range(n_blocks)
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
upsample = (
|
| 194 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
| 195 |
+
if not is_last
|
| 196 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 197 |
+
)
|
| 198 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
| 199 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
| 200 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
| 201 |
+
self.initialize_weights()
|
| 202 |
+
|
| 203 |
+
def initialize_weights(self):
|
| 204 |
+
for m in self.modules():
|
| 205 |
+
if isinstance(m, nn.Conv1d):
|
| 206 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 207 |
+
if m.bias is not None:
|
| 208 |
+
nn.init.constant_(m.bias, 0)
|
| 209 |
+
elif isinstance(m, nn.GroupNorm):
|
| 210 |
+
nn.init.constant_(m.weight, 1)
|
| 211 |
+
nn.init.constant_(m.bias, 0)
|
| 212 |
+
elif isinstance(m, nn.Linear):
|
| 213 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 214 |
+
if m.bias is not None:
|
| 215 |
+
nn.init.constant_(m.bias, 0)
|
| 216 |
+
|
| 217 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
| 218 |
+
"""Forward pass of the UNet1DConditional model.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
| 222 |
+
mask (_type_): shape (batch_size, 1, time)
|
| 223 |
+
t (_type_): shape (batch_size)
|
| 224 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
| 225 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
| 226 |
+
|
| 227 |
+
Raises:
|
| 228 |
+
ValueError: _description_
|
| 229 |
+
ValueError: _description_
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
_type_: _description_
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
t = self.time_embeddings(t).to(t.dtype)
|
| 236 |
+
t = self.time_mlp(t)
|
| 237 |
+
|
| 238 |
+
x = pack([x, mu], "b * t")[0]
|
| 239 |
+
|
| 240 |
+
if spks is not None:
|
| 241 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
| 242 |
+
x = pack([x, spks], "b * t")[0]
|
| 243 |
+
if cond is not None:
|
| 244 |
+
x = pack([x, cond], "b * t")[0]
|
| 245 |
+
|
| 246 |
+
hiddens = []
|
| 247 |
+
masks = [mask]
|
| 248 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
| 249 |
+
mask_down = masks[-1]
|
| 250 |
+
x = resnet(x, mask_down, t)
|
| 251 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 252 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
| 253 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 254 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 255 |
+
for transformer_block in transformer_blocks:
|
| 256 |
+
x = transformer_block(
|
| 257 |
+
hidden_states=x,
|
| 258 |
+
attention_mask=attn_mask,
|
| 259 |
+
timestep=t,
|
| 260 |
+
)
|
| 261 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 262 |
+
hiddens.append(x) # Save hidden states for skip connections
|
| 263 |
+
x = downsample(x * mask_down)
|
| 264 |
+
masks.append(mask_down[:, :, ::2])
|
| 265 |
+
masks = masks[:-1]
|
| 266 |
+
mask_mid = masks[-1]
|
| 267 |
+
|
| 268 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
| 269 |
+
x = resnet(x, mask_mid, t)
|
| 270 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 271 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
| 272 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 273 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 274 |
+
for transformer_block in transformer_blocks:
|
| 275 |
+
x = transformer_block(
|
| 276 |
+
hidden_states=x,
|
| 277 |
+
attention_mask=attn_mask,
|
| 278 |
+
timestep=t,
|
| 279 |
+
)
|
| 280 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 281 |
+
|
| 282 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
| 283 |
+
mask_up = masks.pop()
|
| 284 |
+
skip = hiddens.pop()
|
| 285 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
| 286 |
+
x = resnet(x, mask_up, t)
|
| 287 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 288 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
| 289 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 290 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 291 |
+
for transformer_block in transformer_blocks:
|
| 292 |
+
x = transformer_block(
|
| 293 |
+
hidden_states=x,
|
| 294 |
+
attention_mask=attn_mask,
|
| 295 |
+
timestep=t,
|
| 296 |
+
)
|
| 297 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 298 |
+
x = upsample(x * mask_up)
|
| 299 |
+
x = self.final_block(x, mask_up)
|
| 300 |
+
output = self.final_proj(x * mask_up)
|
| 301 |
+
return output * mask
|
cosyvoice/flow/flow.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import random
|
| 16 |
+
from typing import Dict, Optional
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from torch.nn import functional as F
|
| 20 |
+
from omegaconf import DictConfig
|
| 21 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MaskedDiffWithXvec(torch.nn.Module):
|
| 25 |
+
def __init__(self,
|
| 26 |
+
input_size: int = 512,
|
| 27 |
+
output_size: int = 80,
|
| 28 |
+
spk_embed_dim: int = 192,
|
| 29 |
+
output_type: str = "mel",
|
| 30 |
+
vocab_size: int = 4096,
|
| 31 |
+
input_frame_rate: int = 50,
|
| 32 |
+
only_mask_loss: bool = True,
|
| 33 |
+
encoder: torch.nn.Module = None,
|
| 34 |
+
length_regulator: torch.nn.Module = None,
|
| 35 |
+
decoder: torch.nn.Module = None,
|
| 36 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
| 37 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
| 38 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
| 39 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
| 40 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
| 41 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
| 42 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.input_size = input_size
|
| 45 |
+
self.output_size = output_size
|
| 46 |
+
self.decoder_conf = decoder_conf
|
| 47 |
+
self.mel_feat_conf = mel_feat_conf
|
| 48 |
+
self.vocab_size = vocab_size
|
| 49 |
+
self.output_type = output_type
|
| 50 |
+
self.input_frame_rate = input_frame_rate
|
| 51 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
| 52 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
| 53 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
| 54 |
+
self.encoder = encoder
|
| 55 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
| 56 |
+
self.decoder = decoder
|
| 57 |
+
self.length_regulator = length_regulator
|
| 58 |
+
self.only_mask_loss = only_mask_loss
|
| 59 |
+
|
| 60 |
+
def forward(
|
| 61 |
+
self,
|
| 62 |
+
batch: dict,
|
| 63 |
+
device: torch.device,
|
| 64 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
| 65 |
+
token = batch['speech_token'].to(device)
|
| 66 |
+
token_len = batch['speech_token_len'].to(device)
|
| 67 |
+
feat = batch['speech_feat'].to(device)
|
| 68 |
+
feat_len = batch['speech_feat_len'].to(device)
|
| 69 |
+
embedding = batch['embedding'].to(device)
|
| 70 |
+
|
| 71 |
+
# xvec projection
|
| 72 |
+
embedding = F.normalize(embedding, dim=1)
|
| 73 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 74 |
+
|
| 75 |
+
# concat text and prompt_text
|
| 76 |
+
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
| 77 |
+
# Clamp tokens to valid vocabulary range
|
| 78 |
+
token = torch.clamp(token, min=0, max=self.vocab_size - 1)
|
| 79 |
+
token = self.input_embedding(token) * mask
|
| 80 |
+
|
| 81 |
+
# text encode
|
| 82 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 83 |
+
h = self.encoder_proj(h)
|
| 84 |
+
h, h_lengths = self.length_regulator(h, feat_len)
|
| 85 |
+
|
| 86 |
+
# get conditions
|
| 87 |
+
conds = torch.zeros(feat.shape, device=token.device)
|
| 88 |
+
for i, j in enumerate(feat_len):
|
| 89 |
+
if random.random() < 0.5:
|
| 90 |
+
continue
|
| 91 |
+
index = random.randint(0, int(0.3 * j))
|
| 92 |
+
conds[i, :index] = feat[i, :index]
|
| 93 |
+
conds = conds.transpose(1, 2)
|
| 94 |
+
|
| 95 |
+
mask = (~make_pad_mask(feat_len)).to(h)
|
| 96 |
+
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
| 97 |
+
loss, _ = self.decoder.compute_loss(
|
| 98 |
+
feat.transpose(1, 2).contiguous(),
|
| 99 |
+
mask.unsqueeze(1),
|
| 100 |
+
h.transpose(1, 2).contiguous(),
|
| 101 |
+
embedding,
|
| 102 |
+
cond=conds
|
| 103 |
+
)
|
| 104 |
+
return {'loss': loss}
|
| 105 |
+
|
| 106 |
+
@torch.inference_mode()
|
| 107 |
+
def inference(self,
|
| 108 |
+
token,
|
| 109 |
+
token_len,
|
| 110 |
+
prompt_token,
|
| 111 |
+
prompt_token_len,
|
| 112 |
+
prompt_feat,
|
| 113 |
+
prompt_feat_len,
|
| 114 |
+
embedding,
|
| 115 |
+
flow_cache):
|
| 116 |
+
assert token.shape[0] == 1
|
| 117 |
+
# xvec projection
|
| 118 |
+
embedding = F.normalize(embedding, dim=1)
|
| 119 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 120 |
+
|
| 121 |
+
# concat text and prompt_text
|
| 122 |
+
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
| 123 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
| 124 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 125 |
+
# Clamp tokens to valid vocabulary range
|
| 126 |
+
token = torch.clamp(token, min=0, max=self.vocab_size - 1)
|
| 127 |
+
token = self.input_embedding(token) * mask
|
| 128 |
+
|
| 129 |
+
# text encode
|
| 130 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 131 |
+
h = self.encoder_proj(h)
|
| 132 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
| 133 |
+
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
| 134 |
+
|
| 135 |
+
# get conditions
|
| 136 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
| 137 |
+
conds[:, :mel_len1] = prompt_feat
|
| 138 |
+
conds = conds.transpose(1, 2)
|
| 139 |
+
|
| 140 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
| 141 |
+
feat, flow_cache = self.decoder(
|
| 142 |
+
mu=h.transpose(1, 2).contiguous(),
|
| 143 |
+
mask=mask.unsqueeze(1),
|
| 144 |
+
spks=embedding,
|
| 145 |
+
cond=conds,
|
| 146 |
+
n_timesteps=10,
|
| 147 |
+
prompt_len=mel_len1,
|
| 148 |
+
flow_cache=flow_cache
|
| 149 |
+
)
|
| 150 |
+
feat = feat[:, :, mel_len1:]
|
| 151 |
+
assert feat.shape[2] == mel_len2
|
| 152 |
+
return feat, flow_cache
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
| 156 |
+
def __init__(self,
|
| 157 |
+
input_size: int = 512,
|
| 158 |
+
output_size: int = 80,
|
| 159 |
+
spk_embed_dim: int = 192,
|
| 160 |
+
output_type: str = "mel",
|
| 161 |
+
vocab_size: int = 4096,
|
| 162 |
+
input_frame_rate: int = 50,
|
| 163 |
+
only_mask_loss: bool = True,
|
| 164 |
+
token_mel_ratio: int = 2,
|
| 165 |
+
pre_lookahead_len: int = 3,
|
| 166 |
+
encoder: torch.nn.Module = None,
|
| 167 |
+
decoder: torch.nn.Module = None,
|
| 168 |
+
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
| 169 |
+
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
| 170 |
+
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
| 171 |
+
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
| 172 |
+
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
| 173 |
+
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
| 174 |
+
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.input_size = input_size
|
| 177 |
+
self.output_size = output_size
|
| 178 |
+
self.decoder_conf = decoder_conf
|
| 179 |
+
self.mel_feat_conf = mel_feat_conf
|
| 180 |
+
self.vocab_size = vocab_size
|
| 181 |
+
self.output_type = output_type
|
| 182 |
+
self.input_frame_rate = input_frame_rate
|
| 183 |
+
logging.info(f"input frame rate={self.input_frame_rate}")
|
| 184 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
| 185 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
| 186 |
+
self.encoder = encoder
|
| 187 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
| 188 |
+
self.decoder = decoder
|
| 189 |
+
self.only_mask_loss = only_mask_loss
|
| 190 |
+
self.token_mel_ratio = token_mel_ratio
|
| 191 |
+
self.pre_lookahead_len = pre_lookahead_len
|
| 192 |
+
|
| 193 |
+
@torch.inference_mode()
|
| 194 |
+
def inference(self,
|
| 195 |
+
token,
|
| 196 |
+
token_len,
|
| 197 |
+
prompt_token,
|
| 198 |
+
prompt_token_len,
|
| 199 |
+
prompt_feat,
|
| 200 |
+
prompt_feat_len,
|
| 201 |
+
embedding,
|
| 202 |
+
finalize):
|
| 203 |
+
assert token.shape[0] == 1
|
| 204 |
+
# xvec projection
|
| 205 |
+
embedding = F.normalize(embedding, dim=1)
|
| 206 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 207 |
+
|
| 208 |
+
# concat text and prompt_text
|
| 209 |
+
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
| 210 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
| 211 |
+
# Clamp tokens to valid vocabulary range
|
| 212 |
+
token = torch.clamp(token, min=0, max=self.vocab_size - 1)
|
| 213 |
+
token = self.input_embedding(token) * mask
|
| 214 |
+
|
| 215 |
+
# text encode
|
| 216 |
+
h, h_lengths = self.encoder(token, token_len)
|
| 217 |
+
if finalize is False:
|
| 218 |
+
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
| 219 |
+
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
| 220 |
+
h = self.encoder_proj(h)
|
| 221 |
+
|
| 222 |
+
# get conditions
|
| 223 |
+
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
| 224 |
+
conds[:, :mel_len1] = prompt_feat
|
| 225 |
+
conds = conds.transpose(1, 2)
|
| 226 |
+
|
| 227 |
+
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
| 228 |
+
feat, _ = self.decoder(
|
| 229 |
+
mu=h.transpose(1, 2).contiguous(),
|
| 230 |
+
mask=mask.unsqueeze(1),
|
| 231 |
+
spks=embedding,
|
| 232 |
+
cond=conds,
|
| 233 |
+
n_timesteps=10
|
| 234 |
+
)
|
| 235 |
+
feat = feat[:, :, mel_len1:]
|
| 236 |
+
assert feat.shape[2] == mel_len2
|
| 237 |
+
return feat, None
|
cosyvoice/flow/flow_matching.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import onnxruntime
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from matcha.models.components.flow_matching import BASECFM
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ConditionalCFM(BASECFM):
|
| 21 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
| 22 |
+
super().__init__(
|
| 23 |
+
n_feats=in_channels,
|
| 24 |
+
cfm_params=cfm_params,
|
| 25 |
+
n_spks=n_spks,
|
| 26 |
+
spk_emb_dim=spk_emb_dim,
|
| 27 |
+
)
|
| 28 |
+
self.t_scheduler = cfm_params.t_scheduler
|
| 29 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
| 30 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
| 31 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
| 32 |
+
# Just change the architecture of the estimator here
|
| 33 |
+
self.estimator = estimator
|
| 34 |
+
|
| 35 |
+
@torch.inference_mode()
|
| 36 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
| 37 |
+
"""Forward diffusion
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
mu (torch.Tensor): output of encoder
|
| 41 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 42 |
+
mask (torch.Tensor): output_mask
|
| 43 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 44 |
+
n_timesteps (int): number of diffusion steps
|
| 45 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 46 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 47 |
+
shape: (batch_size, spk_emb_dim)
|
| 48 |
+
cond: Not used but kept for future purposes
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
sample: generated mel-spectrogram
|
| 52 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
z = torch.randn_like(mu) * temperature
|
| 56 |
+
# Handle None flow_cache
|
| 57 |
+
if flow_cache is not None:
|
| 58 |
+
cache_size = flow_cache.shape[2]
|
| 59 |
+
# fix prompt and overlap part mu and z
|
| 60 |
+
if cache_size != 0:
|
| 61 |
+
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
|
| 62 |
+
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
|
| 63 |
+
else:
|
| 64 |
+
cache_size = 0
|
| 65 |
+
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
| 66 |
+
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
| 67 |
+
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
|
| 68 |
+
|
| 69 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
| 70 |
+
if self.t_scheduler == 'cosine':
|
| 71 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 72 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
| 73 |
+
|
| 74 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
| 75 |
+
"""
|
| 76 |
+
Fixed euler solver for ODEs.
|
| 77 |
+
Args:
|
| 78 |
+
x (torch.Tensor): random noise
|
| 79 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 80 |
+
shape: (n_timesteps + 1,)
|
| 81 |
+
mu (torch.Tensor): output of encoder
|
| 82 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 83 |
+
mask (torch.Tensor): output_mask
|
| 84 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 85 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 86 |
+
shape: (batch_size, spk_emb_dim)
|
| 87 |
+
cond: Not used but kept for future purposes
|
| 88 |
+
"""
|
| 89 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 90 |
+
t = t.unsqueeze(dim=0)
|
| 91 |
+
|
| 92 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 93 |
+
# Or in future might add like a return_all_steps flag
|
| 94 |
+
sol = []
|
| 95 |
+
|
| 96 |
+
if self.inference_cfg_rate > 0:
|
| 97 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
| 98 |
+
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 99 |
+
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
| 100 |
+
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 101 |
+
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
| 102 |
+
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
| 103 |
+
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
| 104 |
+
else:
|
| 105 |
+
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
|
| 106 |
+
for step in range(1, len(t_span)):
|
| 107 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
| 108 |
+
if self.inference_cfg_rate > 0:
|
| 109 |
+
x_in[:] = x
|
| 110 |
+
mask_in[:] = mask
|
| 111 |
+
mu_in[0] = mu
|
| 112 |
+
t_in[:] = t.unsqueeze(0)
|
| 113 |
+
spks_in[0] = spks
|
| 114 |
+
cond_in[0] = cond
|
| 115 |
+
else:
|
| 116 |
+
x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
|
| 117 |
+
dphi_dt = self.forward_estimator(
|
| 118 |
+
x_in, mask_in,
|
| 119 |
+
mu_in, t_in,
|
| 120 |
+
spks_in,
|
| 121 |
+
cond_in
|
| 122 |
+
)
|
| 123 |
+
if self.inference_cfg_rate > 0:
|
| 124 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
| 125 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
| 126 |
+
x = x + dt * dphi_dt
|
| 127 |
+
t = t + dt
|
| 128 |
+
sol.append(x)
|
| 129 |
+
if step < len(t_span) - 1:
|
| 130 |
+
dt = t_span[step + 1] - t
|
| 131 |
+
|
| 132 |
+
return sol[-1].float()
|
| 133 |
+
|
| 134 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
| 135 |
+
if isinstance(self.estimator, torch.nn.Module):
|
| 136 |
+
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
| 137 |
+
elif isinstance(self.estimator, onnxruntime.InferenceSession):
|
| 138 |
+
ort_inputs = {
|
| 139 |
+
'x': x.cpu().numpy(),
|
| 140 |
+
'mask': mask.cpu().numpy(),
|
| 141 |
+
'mu': mu.cpu().numpy(),
|
| 142 |
+
't': t.cpu().numpy(),
|
| 143 |
+
'spks': spks.cpu().numpy(),
|
| 144 |
+
'cond': cond.cpu().numpy()
|
| 145 |
+
}
|
| 146 |
+
output = self.estimator.run(None, ort_inputs)[0]
|
| 147 |
+
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
| 148 |
+
else:
|
| 149 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
| 150 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
| 151 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
| 152 |
+
self.estimator.set_input_shape('t', (2,))
|
| 153 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
| 154 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
| 155 |
+
# run trt engine
|
| 156 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
| 157 |
+
mask.contiguous().data_ptr(),
|
| 158 |
+
mu.contiguous().data_ptr(),
|
| 159 |
+
t.contiguous().data_ptr(),
|
| 160 |
+
spks.contiguous().data_ptr(),
|
| 161 |
+
cond.contiguous().data_ptr(),
|
| 162 |
+
x.data_ptr()])
|
| 163 |
+
return x
|
| 164 |
+
|
| 165 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
| 166 |
+
"""Computes diffusion loss
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
x1 (torch.Tensor): Target
|
| 170 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 171 |
+
mask (torch.Tensor): target mask
|
| 172 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 173 |
+
mu (torch.Tensor): output of encoder
|
| 174 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 175 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 176 |
+
shape: (batch_size, spk_emb_dim)
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
loss: conditional flow matching loss
|
| 180 |
+
y: conditional flow
|
| 181 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 182 |
+
"""
|
| 183 |
+
b, _, t = mu.shape
|
| 184 |
+
|
| 185 |
+
# random timestep
|
| 186 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 187 |
+
if self.t_scheduler == 'cosine':
|
| 188 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 189 |
+
# sample noise p(x_0)
|
| 190 |
+
z = torch.randn_like(x1)
|
| 191 |
+
|
| 192 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 193 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 194 |
+
|
| 195 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
| 196 |
+
if self.training_cfg_rate > 0:
|
| 197 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
| 198 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 199 |
+
spks = spks * cfg_mask.view(-1, 1)
|
| 200 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 201 |
+
|
| 202 |
+
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
| 203 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 204 |
+
return loss, y
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class CausalConditionalCFM(ConditionalCFM):
|
| 208 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
| 209 |
+
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
| 210 |
+
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
| 211 |
+
|
| 212 |
+
@torch.inference_mode()
|
| 213 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
| 214 |
+
"""Forward diffusion
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
mu (torch.Tensor): output of encoder
|
| 218 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 219 |
+
mask (torch.Tensor): output_mask
|
| 220 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 221 |
+
n_timesteps (int): number of diffusion steps
|
| 222 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 223 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 224 |
+
shape: (batch_size, spk_emb_dim)
|
| 225 |
+
cond: Not used but kept for future purposes
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
sample: generated mel-spectrogram
|
| 229 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
|
| 233 |
+
if self.fp16 is True:
|
| 234 |
+
z = z.half()
|
| 235 |
+
# fix prompt and overlap part mu and z
|
| 236 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
| 237 |
+
if self.t_scheduler == 'cosine':
|
| 238 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 239 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
cosyvoice/flow/length_regulator.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from cosyvoice.utils.mask import make_pad_mask
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class InterpolateRegulator(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
channels: int,
|
| 25 |
+
sampling_ratios: Tuple,
|
| 26 |
+
out_channels: int = None,
|
| 27 |
+
groups: int = 1,
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.sampling_ratios = sampling_ratios
|
| 31 |
+
out_channels = out_channels or channels
|
| 32 |
+
model = nn.ModuleList([])
|
| 33 |
+
if len(sampling_ratios) > 0:
|
| 34 |
+
for _ in sampling_ratios:
|
| 35 |
+
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
| 36 |
+
norm = nn.GroupNorm(groups, channels)
|
| 37 |
+
act = nn.Mish()
|
| 38 |
+
model.extend([module, norm, act])
|
| 39 |
+
model.append(
|
| 40 |
+
nn.Conv1d(channels, out_channels, 1, 1)
|
| 41 |
+
)
|
| 42 |
+
self.model = nn.Sequential(*model)
|
| 43 |
+
|
| 44 |
+
def forward(self, x, ylens=None):
|
| 45 |
+
# x in (B, T, D)
|
| 46 |
+
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
| 47 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
| 48 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
| 49 |
+
olens = ylens
|
| 50 |
+
return out * mask, olens
|
| 51 |
+
|
| 52 |
+
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
| 53 |
+
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
| 54 |
+
# x in (B, T, D)
|
| 55 |
+
if x2.shape[1] > 40:
|
| 56 |
+
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
| 57 |
+
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
| 58 |
+
mode='linear')
|
| 59 |
+
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
| 60 |
+
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
| 61 |
+
else:
|
| 62 |
+
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
| 63 |
+
if x1.shape[1] != 0:
|
| 64 |
+
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
| 65 |
+
x = torch.concat([x1, x2], dim=2)
|
| 66 |
+
else:
|
| 67 |
+
x = x2
|
| 68 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
| 69 |
+
return out, mel_len1 + mel_len2
|
cosyvoice/hifigan/discriminator.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn.utils import weight_norm
|
| 4 |
+
from typing import List, Optional, Tuple
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torchaudio.transforms import Spectrogram
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MultipleDiscriminator(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self, mpd: nn.Module, mrd: nn.Module
|
| 12 |
+
):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.mpd = mpd
|
| 15 |
+
self.mrd = mrd
|
| 16 |
+
|
| 17 |
+
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
|
| 18 |
+
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
| 19 |
+
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
|
| 20 |
+
y_d_rs += this_y_d_rs
|
| 21 |
+
y_d_gs += this_y_d_gs
|
| 22 |
+
fmap_rs += this_fmap_rs
|
| 23 |
+
fmap_gs += this_fmap_gs
|
| 24 |
+
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
|
| 25 |
+
y_d_rs += this_y_d_rs
|
| 26 |
+
y_d_gs += this_y_d_gs
|
| 27 |
+
fmap_rs += this_fmap_rs
|
| 28 |
+
fmap_gs += this_fmap_gs
|
| 29 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MultiResolutionDiscriminator(nn.Module):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
| 36 |
+
num_embeddings: Optional[int] = None,
|
| 37 |
+
):
|
| 38 |
+
"""
|
| 39 |
+
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
| 40 |
+
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
| 44 |
+
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
| 45 |
+
Defaults to None.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.discriminators = nn.ModuleList(
|
| 50 |
+
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def forward(
|
| 54 |
+
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
| 55 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
| 56 |
+
y_d_rs = []
|
| 57 |
+
y_d_gs = []
|
| 58 |
+
fmap_rs = []
|
| 59 |
+
fmap_gs = []
|
| 60 |
+
|
| 61 |
+
for d in self.discriminators:
|
| 62 |
+
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
| 63 |
+
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
| 64 |
+
y_d_rs.append(y_d_r)
|
| 65 |
+
fmap_rs.append(fmap_r)
|
| 66 |
+
y_d_gs.append(y_d_g)
|
| 67 |
+
fmap_gs.append(fmap_g)
|
| 68 |
+
|
| 69 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DiscriminatorR(nn.Module):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
window_length: int,
|
| 76 |
+
num_embeddings: Optional[int] = None,
|
| 77 |
+
channels: int = 32,
|
| 78 |
+
hop_factor: float = 0.25,
|
| 79 |
+
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
| 80 |
+
):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.window_length = window_length
|
| 83 |
+
self.hop_factor = hop_factor
|
| 84 |
+
self.spec_fn = Spectrogram(
|
| 85 |
+
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
| 86 |
+
)
|
| 87 |
+
n_fft = window_length // 2 + 1
|
| 88 |
+
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
| 89 |
+
self.bands = bands
|
| 90 |
+
convs = lambda: nn.ModuleList(
|
| 91 |
+
[
|
| 92 |
+
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
| 93 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 94 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 95 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 96 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
| 97 |
+
]
|
| 98 |
+
)
|
| 99 |
+
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
| 100 |
+
|
| 101 |
+
if num_embeddings is not None:
|
| 102 |
+
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
| 103 |
+
torch.nn.init.zeros_(self.emb.weight)
|
| 104 |
+
|
| 105 |
+
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
| 106 |
+
|
| 107 |
+
def spectrogram(self, x):
|
| 108 |
+
# Remove DC offset
|
| 109 |
+
x = x - x.mean(dim=-1, keepdims=True)
|
| 110 |
+
# Peak normalize the volume of input audio
|
| 111 |
+
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
| 112 |
+
x = self.spec_fn(x)
|
| 113 |
+
x = torch.view_as_real(x)
|
| 114 |
+
x = rearrange(x, "b f t c -> b c t f")
|
| 115 |
+
# Split into bands
|
| 116 |
+
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
|
| 117 |
+
return x_bands
|
| 118 |
+
|
| 119 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
| 120 |
+
x_bands = self.spectrogram(x)
|
| 121 |
+
fmap = []
|
| 122 |
+
x = []
|
| 123 |
+
for band, stack in zip(x_bands, self.band_convs):
|
| 124 |
+
for i, layer in enumerate(stack):
|
| 125 |
+
band = layer(band)
|
| 126 |
+
band = torch.nn.functional.leaky_relu(band, 0.1)
|
| 127 |
+
if i > 0:
|
| 128 |
+
fmap.append(band)
|
| 129 |
+
x.append(band)
|
| 130 |
+
x = torch.cat(x, dim=-1)
|
| 131 |
+
if cond_embedding_id is not None:
|
| 132 |
+
emb = self.emb(cond_embedding_id)
|
| 133 |
+
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
| 134 |
+
else:
|
| 135 |
+
h = 0
|
| 136 |
+
x = self.conv_post(x)
|
| 137 |
+
fmap.append(x)
|
| 138 |
+
x += h
|
| 139 |
+
|
| 140 |
+
return x, fmap
|