Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -0,0 +1,1395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time, random
|
| 4 |
+
from random import choice
|
| 5 |
+
from typing import List, Dict
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import music21
|
| 10 |
+
import numpy as np
|
| 11 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import wave
|
| 14 |
+
import struct
|
| 15 |
+
import ffmpeg
|
| 16 |
+
import tempfile
|
| 17 |
+
from pydub import AudioSegment
|
| 18 |
+
from moviepy.editor import VideoFileClip, AudioFileClip
|
| 19 |
+
from torch.utils.data import Dataset, DataLoader
|
| 20 |
+
from torch.utils.data import DataLoader, SubsetRandomSampler
|
| 21 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 22 |
+
from torch.nn.utils import clip_grad_norm_
|
| 23 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 24 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 25 |
+
|
| 26 |
+
# Base paths: local by default, Google Drive when running inside Colab.
Gbase = "./"
cache_dir = "./hf/"

try:
    import google.colab
    from google.colab import drive

    IN_COLAB = True
    drive.mount('/gdrive', force_remount=True)
    Gbase = "/gdrive/MyDrive/generate/"
    cache_dir = "/gdrive/MyDrive/hf/"
    sys.path.append(Gbase)
except ImportError:
    # BUG FIX: was a bare `except:`, which also swallows SystemExit /
    # KeyboardInterrupt and hides real errors from drive.mount. Only an
    # ImportError means "not running in Colab".
    IN_COLAB = False
    Gbase = "./"
    cache_dir = "./hf/"

# Paths where model / optimizer / evaluator state is persisted.
ModelPath = os.path.join(Gbase, 'music_generation_model.pth')
OptimizerPath = os.path.join(Gbase, 'optimizer_state.pth')
DiscriminatorModelPath = os.path.join(Gbase, 'discriminator_model.pth')
DiscriminatorOptimizerPath = os.path.join(Gbase, 'discriminator_optimizer_state.pth')
EvaluatorPath = os.path.join(Gbase, 'music_tag_evaluator.pkl')
|
| 50 |
+
|
| 51 |
+
# Tag vocabulary: each category maps to the tag values it may take.
# Note that 'Simple'/'Complex' appear in both 'harmony' and 'rhythm'.
MUSIC_TAGS = dict(
    emotions=['Happy', 'Sad', 'Angry', 'Peaceful', 'Neutral'],
    genres=['Classical', 'Jazz', 'Rock', 'Electronic'],
    tempo=['Slow', 'Medium', 'Fast'],
    instrumentation=['Piano', 'Guitar', 'Synthesizer'],
    harmony=['Consonant', 'Dissonant', 'Complex', 'Simple'],
    dynamics=['Dynamic', 'Static'],
    rhythm=['Simple', 'Complex'],
)
|
| 61 |
+
|
| 62 |
+
def randomMusicTags():
    """Draw one random tag from every MUSIC_TAGS category.

    Returns a dict mapping each category name to a single randomly
    chosen tag from that category's vocabulary.
    """
    return {category: choice(options) for category, options in MUSIC_TAGS.items()}
|
| 64 |
+
|
| 65 |
+
# Smoke-check at import time: print one randomly sampled tag combination.
print("随机生成的音乐标签:", randomMusicTags())
|
| 66 |
+
|
| 67 |
+
def get_scale_notes(key_str: str, octave_range=(2, 6)) -> List[int]:
    """Return the MIDI pitch numbers of the scale of ``key_str``.

    Args:
        key_str: key/tonic name understood by music21 (e.g. 'C', 'F#', 'Bb').
        octave_range: inclusive (low, high) octave bounds to cover.

    Returns:
        A flat list of MIDI pitches, one run of scale pitches per octave.
    """
    tonal_key = music21.key.Key(key_str)
    low, high = octave_range
    return [
        p.midi
        for octave in range(low, high + 1)
        for p in tonal_key.getScale().getPitches(f"{key_str}{octave}")
    ]
|
| 78 |
+
|
| 79 |
+
def composer_from_features(features: np.ndarray, key_str: str) -> music21.stream.Stream:
    """Convert an (N, 3) feature array of [pitch, duration, volume] rows into
    a music21 Stream, snapping pitches onto the scale of ``key_str``.

    Pitch 0 encodes a rest (matching midi_to_features); any other pitch is
    clamped to the piano range and mapped to the nearest scale tone.
    """
    s = music21.stream.Stream()

    # Tempo marking (BPM), default 120.
    s.append(music21.tempo.MetronomeMark(number=120))

    # Key signature.
    s.append(music21.key.Key(key_str))

    # Pitches allowed by the chosen key.
    scale_notes = get_scale_notes(key_str)

    # Durations are quantized onto this grid (in quarter lengths).
    acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]

    for feature in features:
        pitch = int(round(feature[0]))
        duration = feature[1]
        volume = feature[2]

        # Snap the duration to the nearest acceptable value.
        duration = min(acceptable_durations, key=lambda x: abs(x - duration))

        # BUG FIX: handle rests (pitch 0) *before* clamping. The original
        # clamped pitch to [21, 108] first, which made the `pitch == 0` rest
        # branch unreachable — every rest was silently turned into a note.
        if pitch == 0:
            s.append(music21.note.Rest(quarterLength=duration))
            continue

        # Clamp the pitch to the piano range A0 (21) .. C8 (108).
        pitch = max(21, min(108, pitch))

        # Map out-of-scale pitches to the nearest scale tone.
        if pitch not in scale_notes:
            pitch = min(scale_notes, key=lambda x: abs(x - pitch))

        # Clamp the velocity to the MIDI range 0..127.
        volume = max(0, min(127, volume))

        n = music21.note.Note(midi=pitch, quarterLength=duration)
        n.volume.velocity = volume
        s.append(n)
    return s
|
| 126 |
+
|
| 127 |
+
import pickle
|
| 128 |
+
|
| 129 |
+
class MusicTagEvaluator:
    """Rule-based evaluator that assigns MUSIC_TAGS labels to generated music.

    Tags are derived heuristically from a rendered music21 stream (harmony,
    rhythm, dynamics, tempo, and an emotion inferred from the other four).
    The evaluator also owns the MultiLabelBinarizer used to turn tag lists
    into fixed-length 0/1 vectors.
    """

    def __init__(self):
        # Keep a reference to the global tag vocabulary.
        self.MUSIC_TAGS = MUSIC_TAGS
        # Flatten every category into one list, then drop duplicates
        # ('Simple'/'Complex' appear in both 'harmony' and 'rhythm', so the
        # two categories share one binarized column each after dedup).
        all_tags = []
        for category in self.MUSIC_TAGS:
            all_tags.extend(self.MUSIC_TAGS[category])
        self.all_tags = list(set(all_tags))  # remove duplicate tags; order is arbitrary
        # NOTE(review): MultiLabelBinarizer sorts its classes_, so column order
        # of transformed labels follows sorted(all_tags), not self.all_tags.
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.all_tags])

    def save(self, path):
        """Pickle this evaluator to *path*."""
        with open(path, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"评估器已保存至 '{path}'。")

    @staticmethod
    def load(path):
        """Load a pickled evaluator from *path*, or create a fresh one if missing."""
        if os.path.exists(path):
            with open(path, 'rb') as f:
                evaluator = pickle.load(f)
            print(f"评估器已从 '{path}' 加载。")
            return evaluator
        else:
            print(f"评估器文件 '{path}' 不存在,将创建新的评估器。")
            return MusicTagEvaluator()

    def evaluate_tags_from_features(self, features: np.ndarray) -> List[str]:
        """Render *features* to a stream in a random key and return its tag list."""
        # Pick a random key so the evaluation is not biased toward one tonality.
        key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
        s = composer_from_features(features, key_str)
        tag_scores = self.evaluate_tags(s)
        tags = []
        # Keep only tags that belong to their category's vocabulary
        # (categories without a score, e.g. genres, are simply omitted).
        for category in self.MUSIC_TAGS:
            tag = tag_scores.get(category)
            if tag in self.MUSIC_TAGS[category]:
                tags.append(tag)
        return tags

    def evaluate_tags(self, generated_music):
        """Score a music21 stream and return a {category: tag} dict.

        Only 'emotions', 'harmony', 'rhythm', 'dynamics' and 'tempo' are
        produced; 'genres' and 'instrumentation' are never assigned here.
        """
        tag_scores = {}

        # Pitch range of all notes in the stream.
        # NOTE(review): pitch_range is computed but never used below.
        pitch_values = [note.pitch.midi for note in generated_music.recurse().notes if isinstance(note, music21.note.Note)]
        pitch_range = max(pitch_values) - min(pitch_values) if pitch_values else 0

        # Evaluate each aspect independently.
        harmony_tag = self._evaluate_harmony(generated_music)
        rhythm_tag = self._evaluate_rhythm(generated_music)
        dynamics_tag = self._evaluate_dynamics(generated_music)
        tempo_tag = self._evaluate_tempo(generated_music)
        emotion_tag = self._evaluate_emotion(harmony_tag, rhythm_tag, dynamics_tag, tempo_tag)

        # Collect the per-category results.
        tag_scores['emotions'] = emotion_tag
        tag_scores['harmony'] = harmony_tag
        tag_scores['rhythm'] = rhythm_tag
        tag_scores['dynamics'] = dynamics_tag
        tag_scores['tempo'] = tempo_tag

        return tag_scores

    def _evaluate_harmony(self, stream):
        # Chordify merges simultaneous notes into chords for analysis.
        chords = stream.chordify()
        chord_types = []
        for element in chords.recurse():
            if isinstance(element, music21.chord.Chord):
                chord_types.append(element.commonName)

        # Classify harmonic complexity by the chord qualities present.
        if any('diminished' in str(ct) or 'augmented' in str(ct) for ct in chord_types):
            harmony_tag = 'Complex'
        elif any('major' in str(ct) or 'minor' in str(ct) for ct in chord_types):
            harmony_tag = 'Consonant'
        else:
            harmony_tag = 'Simple'

        return harmony_tag

    def _evaluate_rhythm(self, stream):
        durations = [note.quarterLength for note in stream.flat.notes]
        # Rhythmic complexity proxy: the number of distinct duration values.
        unique_durations = len(set(durations))

        if unique_durations > 5:
            rhythm_tag = 'Complex'
        else:
            rhythm_tag = 'Simple'

        return rhythm_tag

    def _evaluate_dynamics(self, stream):
        volumes = [note.volume.velocity for note in stream.flat.notes if note.volume.velocity is not None]

        if not volumes:
            # No velocity information at all: treat as static dynamics.
            dynamics_tag = 'Static'
        else:
            dynamics_range = max(volumes) - min(volumes)
            # A wide velocity spread (> 40) counts as dynamic playing.
            if dynamics_range > 40:
                dynamics_tag = 'Dynamic'
            else:
                dynamics_tag = 'Static'

        return dynamics_tag

    def _evaluate_tempo(self, stream):
        tempos = [metronome.number for metronome in stream.recurse() if isinstance(metronome, music21.tempo.MetronomeMark)]
        bpm = tempos[0] if tempos else 120  # default to 120 BPM when no mark is found

        if bpm < 60:
            return 'Slow'
        elif 60 <= bpm < 120:
            return 'Medium'
        else:
            return 'Fast'

    def _evaluate_emotion(self, harmony_tag, rhythm_tag, dynamics_tag, tempo_tag):
        # Heuristic emotion mapping from the other four tags; 'Sad' is the
        # catch-all for every combination not listed below.
        if harmony_tag == 'Complex' and rhythm_tag == 'Complex':
            emotion = 'Angry'
        elif harmony_tag == 'Consonant' and dynamics_tag == 'Dynamic' and tempo_tag == 'Fast':
            emotion = 'Happy'
        elif harmony_tag == 'Simple' and dynamics_tag == 'Static' and tempo_tag == 'Slow':
            emotion = 'Peaceful'
        elif harmony_tag == 'Consonant' and dynamics_tag == 'Static' and tempo_tag == 'Medium':
            emotion = 'Neutral'
        else:
            emotion = 'Sad'

        return emotion
|
| 268 |
+
|
| 269 |
+
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional information to a batch of embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # Inverse frequencies for each even dimension index.
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))

        table = torch.zeros(max_len, d_model)  # [max_len, d_model]
        table[:, 0::2] = torch.sin(positions * inv_freq)  # even dims
        table[:, 1::2] = torch.cos(positions * inv_freq)  # odd dims

        # Registered as a buffer: follows .to(device)/.state_dict(), not trained.
        self.register_buffer('pe', table.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """x: [batch_size, seq_len, d_model] -> same shape with positions added."""
        return x + self.pe[:, :x.size(1), :]
|
| 289 |
+
|
| 290 |
+
class MusicGenerationModel(nn.Module):
    """Transformer encoder emitting note features and per-step tag probabilities."""

    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags, max_seq_length=500):
        super().__init__()
        self.d_model = d_model
        self.input_linear = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=max_seq_length)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_encoder_layers)
        self.fc_music = nn.Linear(d_model, output_dim)
        self.fc_tags = nn.Linear(d_model, num_tags)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode a feature sequence.

        src: [batch_size, seq_len, input_dim]
        Returns (music_output [batch, seq, output_dim],
                 tag_probabilities [batch, seq, num_tags] in [0, 1]).
        """
        # Project, scale by sqrt(d_model), add positions.
        embedded = self.positional_encoding(self.input_linear(src) * np.sqrt(self.d_model))
        # nn.TransformerEncoder expects [seq_len, batch, d_model].
        encoded = self.transformer_encoder(embedded.transpose(0, 1), mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        encoded = self.dropout(encoded.transpose(0, 1))  # back to [batch, seq, d_model]
        return self.fc_music(encoded), self.sigmoid(self.fc_tags(encoded))
|
| 316 |
+
|
| 317 |
+
class Discriminator(nn.Module):
    """Transformer-based real/fake classifier over note-feature sequences."""

    def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        self.input_linear = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """src: [batch, seq, input_dim] -> probability of "real", shape [batch, 1]."""
        scaled = self.input_linear(src) * np.sqrt(self.input_linear.out_features)
        # nn.TransformerEncoder expects [seq_len, batch, d_model].
        encoded = self.transformer_encoder(
            self.positional_encoding(scaled).transpose(0, 1),
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        encoded = self.dropout(encoded.transpose(0, 1))  # [batch, seq, d_model]
        # The decision uses only the final time step (mean-pooling also possible).
        return self.sigmoid(self.fc(encoded[:, -1, :]))
|
| 339 |
+
|
| 340 |
+
class MidiDataset(Dataset):
    """Dataset of fixed-length MIDI feature segments with binarized tag labels.

    Segments are cached via torch.save at ``dataset_path`` so MIDI parsing and
    tag evaluation only happen on the first run. ``__getitem__`` returns raw
    samples; ``getAug`` returns augmented ones.
    """

    def __init__(self, midi_files: List[str], max_length: int, dataset_path: str, evaluator: MusicTagEvaluator):
        self.max_length = max_length
        self.dataset_path = dataset_path
        self.evaluator = evaluator
        # Load the cached dataset if it exists; otherwise build it from MIDI files.
        if os.path.exists(self.dataset_path):
            print(f"从 '{self.dataset_path}' 加载数据集")
            try:
                saved_data = torch.load(self.dataset_path)
                self.features = saved_data['features']
                self.labels = saved_data['labels']
                print(f"成功加载数据集,共有 {len(self.features)} 个样本。")
            except Exception as e:
                # Corrupt or incompatible cache: rebuild from the MIDI sources.
                print(f"加载数据集时出错: {e}")
                self._process_midi_files(midi_files)
        else:
            # No cache: process the MIDI files and save the result.
            self._process_midi_files(midi_files)

    def __len__(self):
        return len(self.features)

    def getAug(self, idx):
        """Like __getitem__, but with random volume/duration augmentation applied."""
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]  # [num_tags]
        # Apply data augmentation.
        feature_aug, label_aug =self._augment_data(feature, label)
        # Return as float32 tensors.
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)

    def __getitem__(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]  # [num_tags]
        # No augmentation here: the raw feature/label are returned as-is
        # (augmented access is provided by getAug above).
        feature_aug, label_aug =feature, label
        # Return as float32 tensors.
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)

    def _process_midi_files(self, midi_files):
        """Parse MIDI files, cut them into max_length segments, tag and cache them."""
        print("处理 MIDI 文件以创建数据集...")
        features_list = []
        labels_list = []
        for midi_file in midi_files:
            try:
                stream = music21.converter.parse(midi_file)
                # Convert the track to a feature sequence.
                features = self.midi_to_features(stream)
                if len(features) < self.max_length:
                    # Skip samples that are shorter than one segment.
                    continue
                else:
                    # Split the features into segments of length max_length.
                    num_segments = len(features) // self.max_length
                    for i in range(num_segments):
                        segment = features[i*self.max_length : (i+1)*self.max_length]
                        if len(segment) < self.max_length:
                            continue  # skip incomplete segments
                        # Use the evaluator to assign tags to each segment.
                        tags = self.evaluator.evaluate_tags_from_features(segment)
                        # Binarize the tag list into a fixed-length 0/1 vector.
                        tag_binarized = self.evaluator.mlb.transform([tags])[0]
                        features_list.append(segment)
                        labels_list.append(tag_binarized)
            except Exception as e:
                print(f"处理 {midi_file} 时出错: {e}")
        self.features = features_list
        self.labels = labels_list
        # Persist the processed dataset for future runs.
        try:
            torch.save({'features': self.features, 'labels': self.labels}, self.dataset_path)
            print(f"数据集已保存至 '{self.dataset_path}',共有 {len(self.features)} 个样本。")
        except Exception as e:
            print(f"保存数据集时出错: {e}")

    def midi_to_features(self, stream) -> np.ndarray:
        """Convert a music21 stream into an (N, 3) float32 array of
        [pitch, duration, volume] rows; rests are encoded as pitch 0."""
        features = []
        for note in stream.flat.notesAndRests:
            if isinstance(note, music21.note.Note):
                pitch = note.pitch.midi
                duration = note.quarterLength
                volume = note.volume.velocity if note.volume.velocity else 64  # default velocity
            elif isinstance(note, music21.note.Rest):
                pitch = 0  # rests use pitch 0
                duration = note.quarterLength
                volume = 0
            else:
                # Chords and other elements are skipped.
                continue
            features.append([pitch, duration, volume])
        return np.array(features, dtype=np.float32)

    def _augment_data(self, feature, label):
        """Randomly jitter dynamics (volume) and rhythm (duration); update the tempo label."""
        feature_aug = np.copy(feature)
        label_aug = np.copy(label)
        # Random volume (dynamics) scaling.
        volume_change = np.random.uniform(0.8, 1.2)
        feature_aug[:, 2] *= volume_change
        feature_aug[:, 2] = np.clip(feature_aug[:, 2], 0, 127)
        # Random duration (rhythm/tempo) scaling.
        duration_change = np.random.uniform(0.9, 1.1)
        feature_aug[:, 1] *= duration_change
        # Pick the tempo tag implied by the duration change.
        if duration_change > 1.05:
            # Faster tempo.
            tempo_tags = ['Fast']
        elif duration_change < 0.95:
            # Slower tempo.
            tempo_tags = ['Slow']
        else:
            tempo_tags = ['Medium']
        # Reset all tempo tags, then set the chosen one.
        # NOTE(review): the label vector came from mlb.transform, whose column
        # order is mlb.classes_ (sorted), while all_tags is an unsorted set —
        # these indices may hit the wrong columns; verify against classes_.
        for tempo in ['Slow', 'Medium', 'Fast']:
            label_aug[self.evaluator.all_tags.index(tempo)] = 0
        tempo_index = self.evaluator.all_tags.index(tempo_tags[0])
        label_aug[tempo_index] = 1
        return feature_aug, label_aug
|
| 462 |
+
|
| 463 |
+
class MidiDatasetAug(Dataset):
|
| 464 |
+
def __init__(self, midi_files: List[str], max_length: int, dataset_path: str, evaluator: MusicTagEvaluator):
|
| 465 |
+
self.max_length = max_length
|
| 466 |
+
self.dataset_path = dataset_path
|
| 467 |
+
self.evaluator = evaluator
|
| 468 |
+
# 检查数据集文件是否存在
|
| 469 |
+
if os.path.exists(self.dataset_path):
|
| 470 |
+
# 加载已预处理的数据集
|
| 471 |
+
print(f"从 '{self.dataset_path}' 加载数据集")
|
| 472 |
+
try:
|
| 473 |
+
saved_data = torch.load(self.dataset_path)
|
| 474 |
+
self.features = saved_data['features']
|
| 475 |
+
self.labels = saved_data['labels']
|
| 476 |
+
print(f"成功加载数据集,共有 {len(self.features)} 个样本。")
|
| 477 |
+
except Exception as e:
|
| 478 |
+
print(f"加载数据集时出错: {e}")
|
| 479 |
+
self._process_midi_files(midi_files)
|
| 480 |
+
else:
|
| 481 |
+
# 处理 MIDI 文件并保存数据集
|
| 482 |
+
self._process_midi_files(midi_files)
|
| 483 |
+
|
| 484 |
+
def __len__(self):
|
| 485 |
+
return len(self.features)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def __getitem__(self, idx):
|
| 489 |
+
feature = self.features[idx] # [seq_len, input_dim]
|
| 490 |
+
label = self.labels[idx] # [num_tags]
|
| 491 |
+
# 应用数据增强
|
| 492 |
+
feature_aug, label_aug =self._augment_data(feature, label)
|
| 493 |
+
# 返回张量
|
| 494 |
+
return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)
|
| 495 |
+
torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)
|
| 496 |
+
|
| 497 |
+
def _process_midi_files(self, midi_files):
|
| 498 |
+
print("处理 MIDI 文件以创建数据集...")
|
| 499 |
+
features_list = []
|
| 500 |
+
labels_list = []
|
| 501 |
+
for midi_file in midi_files:
|
| 502 |
+
try:
|
| 503 |
+
stream = music21.converter.parse(midi_file)
|
| 504 |
+
# 将音轨转换为特征
|
| 505 |
+
features = self.midi_to_features(stream)
|
| 506 |
+
if len(features) < self.max_length:
|
| 507 |
+
# 跳过长度不足的样本
|
| 508 |
+
continue
|
| 509 |
+
else:
|
| 510 |
+
# 将特征分割成长度为 max_length 的片段
|
| 511 |
+
num_segments = len(features) // self.max_length
|
| 512 |
+
for i in range(num_segments):
|
| 513 |
+
segment = features[i*self.max_length : (i+1)*self.max_length]
|
| 514 |
+
if len(segment) < self.max_length:
|
| 515 |
+
continue # 跳过不完整的片段
|
| 516 |
+
# 使用评估器为每个片段分配标签
|
| 517 |
+
tags = self.evaluator.evaluate_tags_from_features(segment)
|
| 518 |
+
# 二值化标签
|
| 519 |
+
tag_binarized = self.evaluator.mlb.transform([tags])[0]
|
| 520 |
+
features_list.append(segment)
|
| 521 |
+
labels_list.append(tag_binarized)
|
| 522 |
+
except Exception as e:
|
| 523 |
+
print(f"处理 {midi_file} 时出错: {e}")
|
| 524 |
+
self.features = features_list
|
| 525 |
+
self.labels = labels_list
|
| 526 |
+
# 保存数据集
|
| 527 |
+
try:
|
| 528 |
+
torch.save({'features': self.features, 'labels': self.labels}, self.dataset_path)
|
| 529 |
+
print(f"数据集已保存至 '{self.dataset_path}',共有 {len(self.features)} 个样本。")
|
| 530 |
+
except Exception as e:
|
| 531 |
+
print(f"保存数据集时出错: {e}")
|
| 532 |
+
|
| 533 |
+
def midi_to_features(self, stream) -> np.ndarray:
|
| 534 |
+
"""
|
| 535 |
+
将 music21 流对象转换为特征序列。
|
| 536 |
+
"""
|
| 537 |
+
features = []
|
| 538 |
+
for note in stream.flat.notesAndRests:
|
| 539 |
+
if isinstance(note, music21.note.Note):
|
| 540 |
+
pitch = note.pitch.midi
|
| 541 |
+
duration = note.quarterLength
|
| 542 |
+
volume = note.volume.velocity if note.volume.velocity else 64 # 默认音量
|
| 543 |
+
elif isinstance(note, music21.note.Rest):
|
| 544 |
+
pitch = 0 # 休止符音高设为 0
|
| 545 |
+
duration = note.quarterLength
|
| 546 |
+
volume = 0
|
| 547 |
+
else:
|
| 548 |
+
continue
|
| 549 |
+
features.append([pitch, duration, volume])
|
| 550 |
+
return np.array(features, dtype=np.float32)
|
| 551 |
+
|
| 552 |
+
    def _augment_data(self, feature, label):
        """Return an augmented copy of (feature, label).

        Applies random volume (dynamics) and duration (tempo) scaling to the
        [seq_len, 3] feature array, then rewrites the tempo portion of the
        multi-hot label to match the applied tempo change.
        """
        feature_aug = np.copy(feature)
        label_aug = np.copy(label)
        # Randomly scale volume (dynamics) by up to ±20%, clipped to MIDI range.
        volume_change = np.random.uniform(0.8, 1.2)
        feature_aug[:, 2] *= volume_change
        feature_aug[:, 2] = np.clip(feature_aug[:, 2], 0, 127)
        # Randomly scale durations (tempo change) by up to ±10%.
        duration_change = np.random.uniform(0.9, 1.1)
        feature_aug[:, 1] *= duration_change
        # Re-label tempo according to the applied change.
        # NOTE(review): longer durations normally mean a *slower* performance,
        # yet a factor > 1.05 is labelled 'Fast' here — confirm this mapping.
        if duration_change > 1.05:
            tempo_tags = ['Fast']
        elif duration_change < 0.95:
            tempo_tags = ['Slow']
        else:
            tempo_tags = ['Medium']
        # Clear all tempo tags, then set the selected one.
        # Assumes 'Slow'/'Medium'/'Fast' all appear in evaluator.all_tags;
        # .index() raises ValueError otherwise — TODO confirm.
        for tempo in ['Slow', 'Medium', 'Fast']:
            label_aug[self.evaluator.all_tags.index(tempo)] = 0
        tempo_index = self.evaluator.all_tags.index(tempo_tags[0])
        label_aug[tempo_index] = 1
        return feature_aug, label_aug
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class RandomDataset(Dataset):
    """Dataset that synthesizes random note sequences on the fly.

    Every item is a float32 tensor of shape [max_length, 3] holding
    (pitch, duration, volume) triples drawn from fixed random ranges.
    """

    def __init__(self, size: int, max_length: int):
        """Create a randomly generated dataset.

        Args:
            size (int): number of samples the dataset reports.
            max_length (int): sequence length of every sample.
        """
        self.size = size
        self.max_length = max_length

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        n = self.max_length
        # Pitches span the piano range: 21 (A0) .. 108 (C8).
        pitches = np.random.randint(21, 109, size=(n, 1)).astype(np.float32)
        # Durations come from a fixed palette of quarter lengths.
        palette = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
        durations = np.random.choice(palette, size=(n, 1)).astype(np.float32)
        # Velocities are drawn from 40..69 inclusive.
        volumes = np.random.randint(40, 70, size=(n, 1)).astype(np.float32)
        sample = np.concatenate([pitches, durations, volumes], axis=-1)  # [n, 3]
        return torch.tensor(sample, dtype=torch.float32)
|
| 608 |
+
|
| 609 |
+
class MusicGenerator:
|
| 610 |
+
def __init__(self, model: nn.Module, evaluator, device: torch.device, model_path: str, optimizer=None, optimizer_path: str=None, writer: SummaryWriter=None):
|
| 611 |
+
self.model = model.to(device)
|
| 612 |
+
self.evaluator = evaluator
|
| 613 |
+
self.device = device
|
| 614 |
+
self.model_path = model_path
|
| 615 |
+
self.optimizer = optimizer
|
| 616 |
+
self.optimizer_path = optimizer_path
|
| 617 |
+
self.writer = writer
|
| 618 |
+
self._load_model()
|
| 619 |
+
# 定义归一化和反归一化参数
|
| 620 |
+
self.min_pitch = 21
|
| 621 |
+
self.max_pitch = 108
|
| 622 |
+
self.min_duration = 0.15
|
| 623 |
+
self.max_duration = 1.5
|
| 624 |
+
self.min_volume = 40
|
| 625 |
+
self.max_volume = 85
|
| 626 |
+
|
| 627 |
+
def _load_model(self):
|
| 628 |
+
"""自动载入已存在的模型权重,如果存在的话。"""
|
| 629 |
+
if os.path.exists(self.model_path):
|
| 630 |
+
try:
|
| 631 |
+
self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
| 632 |
+
self.model.to(self.device)
|
| 633 |
+
self.model.eval()
|
| 634 |
+
print(f"已成功载入模型权重从 '{self.model_path}'。")
|
| 635 |
+
except Exception as e:
|
| 636 |
+
print(f"载入模型权重时出错: {e},将初始化新模型。")
|
| 637 |
+
else:
|
| 638 |
+
print("未找到已保存的模型,将初始化新模型。")
|
| 639 |
+
|
| 640 |
+
# 加载优化器状态
|
| 641 |
+
if self.optimizer and self.optimizer_path and os.path.exists(self.optimizer_path):
|
| 642 |
+
try:
|
| 643 |
+
self.optimizer.load_state_dict(torch.load(self.optimizer_path, map_location=self.device))
|
| 644 |
+
print(f"已成功载入优化器状态从 '{self.optimizer_path}'。")
|
| 645 |
+
except Exception as e:
|
| 646 |
+
print(f"载入优化器状态时出错: {e},将初始化新优化器。")
|
| 647 |
+
else:
|
| 648 |
+
if self.optimizer and self.optimizer_path:
|
| 649 |
+
print("未找到已保存的优化器状态,将初始化新优化器。")
|
| 650 |
+
|
| 651 |
+
    def save_model(self, epoch: int, loss: float):
        """Persist the model weights and (optionally) the optimizer state.

        Args:
            epoch: epoch number, used only for TensorBoard logging.
            loss: loss value logged under 'Loss/Save'.
        """
        try:
            # Legacy (non-zipfile) serialization keeps checkpoints loadable
            # by older torch versions.
            torch.save(self.model.state_dict(), self.model_path, _use_new_zipfile_serialization=False)
            if self.optimizer and self.optimizer_path:
                torch.save(self.optimizer.state_dict(), self.optimizer_path, _use_new_zipfile_serialization=False)
            print(f"模型和优化器已保存至 '{self.model_path}' 和 '{self.optimizer_path}'。")
            if self.writer:
                self.writer.add_scalar('Loss/Save', loss, epoch)
        except Exception as e:
            print(f"保存模型或优化器时出错: {e}")
|
| 662 |
+
|
| 663 |
+
    def train_epoch(self, dataloader: DataLoader, optimizer, criterion_music, criterion_tags, epoch: int):
        """
        Train for one epoch.

        Each batch is (features, labels); the model predicts the final
        timestep of every sequence from the preceding ones, plus the
        multi-label tags at that same timestep.

        Returns:
            float: average combined (music + tags) loss over the epoch.
        """
        self.model.train()
        total_loss = 0.0
        for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
            batch_features = batch_features.to(self.device)  # [batch_size, seq_len, input_dim]
            batch_labels = batch_labels.to(self.device)  # [batch_size, num_tags]
            inputs = batch_features[:, :-1, :]  # [batch_size, seq_len-1, input_dim] — all but last step
            targets = batch_features[:, -1, :]  # [batch_size, input_dim] — the step to predict

            optimizer.zero_grad()
            music_output, tag_probabilities = self.model(inputs)  # music output: [batch, seq_len-1, output_dim]

            # Loss is computed only on the last timestep's prediction.
            loss_music = criterion_music(music_output[:, -1, :], targets)

            # Tag loss against the dataset-provided multi-hot labels.
            loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)

            # Combined loss.
            loss = loss_music + loss_tags
            loss.backward()

            # Gradient clipping for stability.
            clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            optimizer.step()

            total_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Loss/Train', loss.item(), epoch * len(dataloader) + batch_idx)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} 平均损失: {avg_loss:.4f}")
        return avg_loss
|
| 700 |
+
|
| 701 |
+
    def train_epoch_gan(self, dataloader, optimizer_generator, optimizer_discriminator, criterion_music, criterion_tags, criterion_discriminator, discriminator, epoch):
        """
        Train one epoch adversarially (generator vs. discriminator).

        Per batch: the discriminator is updated on real vs. generated
        sequences, then the generator is updated on an adversarial loss
        plus the usual music/tag reconstruction losses.

        Returns:
            float: average generator loss over the epoch.
        """
        self.model.train()
        discriminator.train()
        total_loss = 0.0

        for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
            batch_features = batch_features.to(self.device)  # [batch_size, seq_len, input_dim]
            batch_labels = batch_labels.to(self.device)  # [batch_size, num_tags]
            batch_size = batch_features.size(0)
            seq_len = batch_features.size(1)
            # ---------------------
            # Train the discriminator
            # ---------------------

            # Real samples.
            real_data = batch_features  # [batch_size, seq_len, input_dim]
            real_labels = torch.ones(batch_size, 1).to(self.device)

            # Fake samples: generate autoregressively from noise.
            # NOTE(review): this runs seq_len forward passes per batch —
            # very expensive for long sequences.
            noise = torch.rand(batch_size, seq_len, 3).to(self.device)  # noise in [0,1], matching normalized features
            generated_features = torch.zeros_like(batch_features).to(self.device)
            for i in range(seq_len):
                input_noise = noise[:, :i+1, :]
                fake_data, _ = self.model(input_noise)
                generated_features[:, i, :] = fake_data[:, -1, :]

            fake_data = generated_features.detach()  # [batch_size, seq_len, input_dim]
            fake_labels = torch.zeros(batch_size, 1).to(self.device)

            # Discriminator loss on real data.
            optimizer_discriminator.zero_grad()
            output_real = discriminator(real_data)
            loss_real = criterion_discriminator(output_real, real_labels)

            # Discriminator loss on fake data.
            output_fake = discriminator(fake_data)
            loss_fake = criterion_discriminator(output_fake, fake_labels)

            # Combined loss and backprop.
            loss_discriminator = (loss_real + loss_fake) / 2
            loss_discriminator.backward()
            optimizer_discriminator.step()

            # ---------------------
            # Train the generator
            # ---------------------

            optimizer_generator.zero_grad()
            # Adversarial loss: try to make the discriminator call the fakes real.
            # NOTE(review): fake_data was detach()ed above, so this loss has NO
            # gradient path back into the generator — only loss_music and
            # loss_tags actually update it. Confirm whether this is intended.
            output_fake_for_generator = discriminator(fake_data)
            loss_generator_adv = criterion_discriminator(output_fake_for_generator, real_labels)  # generator's adversarial loss

            # Reconstruction losses for music features and tags.
            music_output, tag_probabilities = self.model(noise)
            targets = batch_features[:, -1, :]  # the real final feature vector
            loss_music = criterion_music(music_output[:, -1, :], targets)

            # Tag loss against the dataset-provided labels.
            loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)

            # Combined generator loss.
            loss_generator = loss_generator_adv + loss_music + loss_tags
            loss_generator.backward()

            # Gradient clipping.
            clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            optimizer_generator.step()

            total_loss += loss_generator.item()
            if self.writer:
                #self.writer.add_scalar('Loss/Generator', loss_generator.item(), epoch * len(dataloader) + batch_idx)
                #self.writer.add_scalar('Loss/Discriminator', loss_discriminator.item(), epoch * len(dataloader) + batch_idx)
                pass

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} 平均生成器损失: {avg_loss:.4f}")
        return avg_loss
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
def generate_music(self, tag_conditions: dict={
|
| 785 |
+
'emotions': 'Neutral',
|
| 786 |
+
'genres': 'Classical',
|
| 787 |
+
'tempo': 'Medium',
|
| 788 |
+
'instrumentation': 'Piano',
|
| 789 |
+
'harmony': 'Simple',
|
| 790 |
+
'dynamics': 'Dynamic',
|
| 791 |
+
'rhythm': 'Simple' # 或 'Complex'
|
| 792 |
+
}, max_length=100, temperature=1.0) -> music21.stream.Stream:
|
| 793 |
+
"""
|
| 794 |
+
根据标签生成音乐。
|
| 795 |
+
"""
|
| 796 |
+
self.model.eval()
|
| 797 |
+
acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
|
| 798 |
+
generated_features = []
|
| 799 |
+
|
| 800 |
+
with torch.no_grad():
|
| 801 |
+
# 随机选择一个调性
|
| 802 |
+
key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
|
| 803 |
+
scale_notes = get_scale_notes(key_str)
|
| 804 |
+
|
| 805 |
+
# 初始输入(随机特征)
|
| 806 |
+
input_feature = torch.zeros(1, 1, 3).to(self.device) # [batch_size=1, seq_len=1, input_dim=3]
|
| 807 |
+
|
| 808 |
+
for _ in range(max_length):
|
| 809 |
+
music_output, tag_probabilities = self.model(input_feature) # [1, seq_len, 3] and [1, seq_len, num_tags]
|
| 810 |
+
music_output_np = music_output.cpu().numpy()[0, -1]
|
| 811 |
+
|
| 812 |
+
# 应用温度控制
|
| 813 |
+
music_output_np = music_output_np / temperature
|
| 814 |
+
|
| 815 |
+
# 使用概率分布进行采样
|
| 816 |
+
pitch = int(round(music_output_np[0]))
|
| 817 |
+
duration = music_output_np[1]
|
| 818 |
+
volume = int(round(music_output_np[2]))
|
| 819 |
+
|
| 820 |
+
# 增加随机变动
|
| 821 |
+
pitch += int(np.random.uniform(-2, 2))
|
| 822 |
+
pitch = max(21, min(108, pitch)) # 限制在钢琴键范围内
|
| 823 |
+
# 将音高映射到最近的音阶音符
|
| 824 |
+
if pitch not in scale_notes:
|
| 825 |
+
pitch = min(scale_notes, key=lambda x: abs(x - pitch))
|
| 826 |
+
duration += np.random.uniform(-0.1, 0.1)
|
| 827 |
+
try:
|
| 828 |
+
duration = min(acceptable_durations, key=lambda x: abs(x - duration))
|
| 829 |
+
except ValueError:
|
| 830 |
+
duration = 1.0 # 默认时值
|
| 831 |
+
volume += int(np.random.uniform(-10, 10))
|
| 832 |
+
volume = max(70, min(100, volume)) # 限制音量范围
|
| 833 |
+
|
| 834 |
+
# 保存特征
|
| 835 |
+
generated_features.append([pitch, duration, volume])
|
| 836 |
+
|
| 837 |
+
# 更新输入
|
| 838 |
+
next_input = torch.tensor([[pitch, duration, volume]], dtype=torch.float32).to(self.device).unsqueeze(0) # [1, 1, 3]
|
| 839 |
+
input_feature = torch.cat((input_feature, next_input), dim=1) # 增加序列长度
|
| 840 |
+
|
| 841 |
+
# 转换为 numpy 数组
|
| 842 |
+
generated_features_array = np.array(generated_features, dtype=np.float32)
|
| 843 |
+
generated_stream = composer_from_features(generated_features_array, key_str)
|
| 844 |
+
|
| 845 |
+
# 评估标签
|
| 846 |
+
tag_scores = self.evaluator.evaluate_tags(generated_stream)
|
| 847 |
+
print("生成的音乐标签:", tag_scores)
|
| 848 |
+
|
| 849 |
+
# 根据情感进行判断并保存
|
| 850 |
+
high_score_emotions = ['Happy', 'Peaceful']
|
| 851 |
+
if tag_scores.get('emotions') in high_score_emotions:
|
| 852 |
+
# 将生成的 MIDI 转换为 WAV
|
| 853 |
+
midi_filename = f'high_score_{int(time.time())}.mid'
|
| 854 |
+
generated_stream.write('midi', fp=os.path.join(Gbase, midi_filename))
|
| 855 |
+
wav_file = self.custom_midi_to_wav(generated_stream, os.path.join(Gbase, f'high_score_{int(time.time())}.wav'))
|
| 856 |
+
print(f"高评分音乐已保存为 WAV 文件: '{wav_file}'")
|
| 857 |
+
|
| 858 |
+
return generated_stream
|
| 859 |
+
|
| 860 |
+
def addMusicToVideo(self, videoPath, tagConditions={
|
| 861 |
+
'emotions': 'Neutral',
|
| 862 |
+
'genres': 'Classical',
|
| 863 |
+
'tempo': 'Medium',
|
| 864 |
+
'instrumentation': 'Piano',
|
| 865 |
+
'harmony': 'Simple',
|
| 866 |
+
'dynamics': 'Dynamic',
|
| 867 |
+
'rhythm': 'Simple' # 或 'Complex'
|
| 868 |
+
}, outputPath=None):
|
| 869 |
+
"""
|
| 870 |
+
根据指定的标签条件生成音乐,并将其附加到视频中,确保音乐的长度与视频一致。
|
| 871 |
+
|
| 872 |
+
参数:
|
| 873 |
+
videoPath (str): 输入视频的路径。
|
| 874 |
+
tagConditions (dict): 用于生成音乐的标签条件。
|
| 875 |
+
outputPath (str, optional): 输出视频的路径。如果未指定,将在原路径基础上添加 '_with_music'。
|
| 876 |
+
|
| 877 |
+
返回:
|
| 878 |
+
str: 输出的视频路径。
|
| 879 |
+
"""
|
| 880 |
+
# 1. 获取视频时长
|
| 881 |
+
try:
|
| 882 |
+
video = VideoFileClip(videoPath)
|
| 883 |
+
duration = video.duration
|
| 884 |
+
print(f"视频时长: {duration} 秒。")
|
| 885 |
+
except Exception as e:
|
| 886 |
+
print(f"无法载入视频: {e}")
|
| 887 |
+
return None
|
| 888 |
+
if not outputPath:
|
| 889 |
+
base, ext = os.path.splitext(videoPath)
|
| 890 |
+
outputPath = f"{base}_with_music{ext}"
|
| 891 |
+
if os.path.exists (outputPath):return outputPath
|
| 892 |
+
# 2. 初始化音频拼接
|
| 893 |
+
combined_audio = AudioSegment.silent(duration=0) # 初始化为空音频
|
| 894 |
+
total_generated_duration = 0 # 总生成时长(毫秒)
|
| 895 |
+
chunk_duration_seconds = 10 # 每次生成音讯的预估时长(秒),根据需要调整
|
| 896 |
+
crossfade_duration = 500 # 淡入淡出持续时间(毫秒)
|
| 897 |
+
|
| 898 |
+
# 3. 逐段生成音频
|
| 899 |
+
print("逐段生成音乐中...")
|
| 900 |
+
while total_generated_duration < duration * 1000: # pydub 使用毫秒
|
| 901 |
+
# 根据剩余时长生成音乐,确保不生成过多
|
| 902 |
+
remaining_duration_ms = duration * 1000 - total_generated_duration
|
| 903 |
+
remaining_duration_seconds = remaining_duration_ms / 1000.0
|
| 904 |
+
current_chunk_length = min(chunk_duration_seconds, remaining_duration_seconds)
|
| 905 |
+
|
| 906 |
+
# 计算所需的音符数量,假设每个音符平均约0.5秒
|
| 907 |
+
estimated_max_length = int(current_chunk_length / 0.5) * 2 # 调整因子根据实际情况
|
| 908 |
+
|
| 909 |
+
# 生成音乐流
|
| 910 |
+
generated_stream = self.generate_music(max_length=100)
|
| 911 |
+
|
| 912 |
+
# 转换为 WAV 文件
|
| 913 |
+
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as wav_temp:
|
| 914 |
+
wav_filename = wav_temp.name
|
| 915 |
+
wav_path = self.custom_midi_to_wav(generated_stream, wav_filename)
|
| 916 |
+
print(f"生成的 WAV 已保存为 '{wav_path}'。")
|
| 917 |
+
|
| 918 |
+
# 加载生成的音频
|
| 919 |
+
try:
|
| 920 |
+
generated_audio = AudioSegment.from_wav(wav_path)
|
| 921 |
+
except Exception as e:
|
| 922 |
+
print(f"加载生成的音频时出错: {e}")
|
| 923 |
+
os.remove(wav_path)
|
| 924 |
+
#break
|
| 925 |
+
|
| 926 |
+
# 拼接音频,应用淡入淡出效果
|
| 927 |
+
if len(combined_audio) == 0:
|
| 928 |
+
# 第一段音频,仅应用淡入
|
| 929 |
+
generated_audio = generated_audio.fade_in(crossfade_duration)
|
| 930 |
+
combined_audio += generated_audio
|
| 931 |
+
else:
|
| 932 |
+
# 之后的音频段,应用淡出和淡入,并设置 crossfade
|
| 933 |
+
generated_audio = generated_audio.fade_in(crossfade_duration)
|
| 934 |
+
combined_audio = combined_audio.append(generated_audio, crossfade=crossfade_duration)
|
| 935 |
+
|
| 936 |
+
total_generated_duration = len(combined_audio)
|
| 937 |
+
|
| 938 |
+
# 删除临时 WAV 文件
|
| 939 |
+
try:
|
| 940 |
+
os.remove(wav_path)
|
| 941 |
+
print(f"已删除临时 WAV 文件 '{wav_path}'。")
|
| 942 |
+
except Exception as e:
|
| 943 |
+
print(f"删除临时 WAV 文件时出错: {e}")
|
| 944 |
+
|
| 945 |
+
# 4. 剪切音频以匹配视频时长
|
| 946 |
+
final_audio = combined_audio[:int(duration * 1000)] # pydub 使用毫秒为单位
|
| 947 |
+
final_wav_path = tempfile.mktemp(suffix='.wav')
|
| 948 |
+
final_audio.export(final_wav_path, format="wav")
|
| 949 |
+
print(f"最终剪切后的 WAV 已保存为 '{final_wav_path}'。")
|
| 950 |
+
|
| 951 |
+
# 5. 定义输出视频路径
|
| 952 |
+
if not outputPath:
|
| 953 |
+
base, ext = os.path.splitext(videoPath)
|
| 954 |
+
outputPath = f"{base}_with_music{ext}"
|
| 955 |
+
|
| 956 |
+
# 6. 使用 moviepy 将音频与视频结合
|
| 957 |
+
try:
|
| 958 |
+
# 载入视频和音频
|
| 959 |
+
video_clip = VideoFileClip(videoPath)
|
| 960 |
+
audio_clip = AudioFileClip(final_wav_path)
|
| 961 |
+
|
| 962 |
+
# 设置音频,确保音频长度与视频一致
|
| 963 |
+
audio_clip = audio_clip.set_duration(video_clip.duration)
|
| 964 |
+
|
| 965 |
+
# 将音频附加到视频
|
| 966 |
+
video_with_audio = video_clip.set_audio(audio_clip)
|
| 967 |
+
|
| 968 |
+
# 输出最终视频
|
| 969 |
+
video_with_audio.write_videofile(outputPath, codec='libx264', audio_codec='aac', verbose=False, logger=None)
|
| 970 |
+
print(f"输出视频已保存为 '{outputPath}'。")
|
| 971 |
+
except Exception as e:
|
| 972 |
+
print(f"结合视频和音频时出错: {e}")
|
| 973 |
+
return None
|
| 974 |
+
finally:
|
| 975 |
+
# 清理 moviepy 生成的资源
|
| 976 |
+
if 'video_clip' in locals():
|
| 977 |
+
video_clip.close()
|
| 978 |
+
if 'audio_clip' in locals():
|
| 979 |
+
audio_clip.close()
|
| 980 |
+
if 'video_with_audio' in locals():
|
| 981 |
+
video_with_audio.close()
|
| 982 |
+
|
| 983 |
+
# 7. 清理临时文件
|
| 984 |
+
try:
|
| 985 |
+
os.remove(final_wav_path)
|
| 986 |
+
print("最终临时 WAV 文件已删除。")
|
| 987 |
+
except Exception as e:
|
| 988 |
+
print(f"删除最终临时 WAV 文件时出错: {e}")
|
| 989 |
+
|
| 990 |
+
return outputPath
|
| 991 |
+
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
def custom_midi_to_wav(self, stream: music21.stream.Stream, wav_filename: str, sample_rate=44100) -> str:
|
| 996 |
+
"""
|
| 997 |
+
自定义的 MIDI 到 WAV 转换函数,使用数学公式生成高质量的音频。
|
| 998 |
+
改进后:声音更加悦耳,符合音符、音阶、乐器的基本要求。
|
| 999 |
+
"""
|
| 1000 |
+
import math
|
| 1001 |
+
|
| 1002 |
+
# 合成参数
|
| 1003 |
+
envelope_attack = 0.01 # 攻击时间
|
| 1004 |
+
envelope_decay = 0.1 # 衰减时间
|
| 1005 |
+
envelope_sustain = 0.8 # 持续水平
|
| 1006 |
+
envelope_release = 0.2 # 释放时间
|
| 1007 |
+
|
| 1008 |
+
# 获取节奏信息
|
| 1009 |
+
metronome_marks = list(stream.metronomeMarkBoundaries())
|
| 1010 |
+
bpm = 120 # 默认 BPM
|
| 1011 |
+
if metronome_marks:
|
| 1012 |
+
# 检查是否存在 MetronomeMark 对象
|
| 1013 |
+
for mark in metronome_marks:
|
| 1014 |
+
if isinstance(mark[2], music21.tempo.MetronomeMark) and mark[2].number:
|
| 1015 |
+
bpm = mark[2].number
|
| 1016 |
+
break
|
| 1017 |
+
|
| 1018 |
+
# 生成时间轴
|
| 1019 |
+
notes = list(stream.flat.getElementsByClass(['Note', 'Chord', 'Rest']))
|
| 1020 |
+
if not notes:
|
| 1021 |
+
print("没有音符可生成音频。")
|
| 1022 |
+
return ""
|
| 1023 |
+
|
| 1024 |
+
# 计算整体时长
|
| 1025 |
+
total_duration = stream.duration.quarterLength * 60 / bpm
|
| 1026 |
+
total_samples = int(total_duration * sample_rate) + 1
|
| 1027 |
+
audio = np.zeros(total_samples)
|
| 1028 |
+
|
| 1029 |
+
current_time = 0
|
| 1030 |
+
|
| 1031 |
+
# 定义乐器的谐波系数,模拟钢琴的谐波
|
| 1032 |
+
harmonic_coeffs = [1.0, 0.5, 0.25, 0.1, 0.05]
|
| 1033 |
+
|
| 1034 |
+
for element in notes:
|
| 1035 |
+
if isinstance(element, music21.note.Rest):
|
| 1036 |
+
# 休止符,更新当前时间
|
| 1037 |
+
duration = element.quarterLength * 60 / bpm # 秒
|
| 1038 |
+
current_time += duration
|
| 1039 |
+
continue
|
| 1040 |
+
|
| 1041 |
+
elif isinstance(element, music21.note.Note):
|
| 1042 |
+
frequencies = [element.pitch.frequency]
|
| 1043 |
+
elif isinstance(element, music21.chord.Chord):
|
| 1044 |
+
frequencies = [p.frequency for p in element.pitches]
|
| 1045 |
+
else:
|
| 1046 |
+
continue
|
| 1047 |
+
|
| 1048 |
+
duration = element.quarterLength * 60 / bpm # 秒
|
| 1049 |
+
# 音量固定为70%
|
| 1050 |
+
volume = 0.6
|
| 1051 |
+
|
| 1052 |
+
# 生成波形时间轴
|
| 1053 |
+
t = np.linspace(0, duration, int(duration * sample_rate), False)
|
| 1054 |
+
|
| 1055 |
+
waveform = np.zeros_like(t)
|
| 1056 |
+
for freq in frequencies:
|
| 1057 |
+
note_waveform = np.zeros_like(t)
|
| 1058 |
+
for idx, coeff in enumerate(harmonic_coeffs):
|
| 1059 |
+
harmonic_freq = freq * (idx + 1)
|
| 1060 |
+
note_waveform += coeff * np.sin(2 * np.pi * harmonic_freq * t)
|
| 1061 |
+
waveform += note_waveform
|
| 1062 |
+
|
| 1063 |
+
# 归一化振幅(避免多个频率叠加导致音量过高)
|
| 1064 |
+
waveform /= len(frequencies) * sum(harmonic_coeffs)
|
| 1065 |
+
|
| 1066 |
+
# 添加 ADSR 包络
|
| 1067 |
+
attack_samples = int(envelope_attack * sample_rate)
|
| 1068 |
+
decay_samples = int(envelope_decay * sample_rate)
|
| 1069 |
+
release_samples = int(envelope_release * sample_rate)
|
| 1070 |
+
sustain_samples = len(waveform) - attack_samples - decay_samples - release_samples
|
| 1071 |
+
if sustain_samples < 0:
|
| 1072 |
+
# 调整 ADSR 以适应短音符
|
| 1073 |
+
total_envelope = envelope_attack + envelope_decay + envelope_release
|
| 1074 |
+
attack_ratio = envelope_attack / total_envelope
|
| 1075 |
+
decay_ratio = envelope_decay / total_envelope
|
| 1076 |
+
release_ratio = envelope_release / total_envelope
|
| 1077 |
+
attack_samples = int(len(waveform) * attack_ratio)
|
| 1078 |
+
decay_samples = int(len(waveform) * decay_ratio)
|
| 1079 |
+
release_samples = len(waveform) - attack_samples - decay_samples
|
| 1080 |
+
sustain_samples = 0
|
| 1081 |
+
|
| 1082 |
+
envelope = np.concatenate([
|
| 1083 |
+
np.linspace(0, 1, attack_samples, False),
|
| 1084 |
+
np.linspace(1, envelope_sustain, decay_samples, False),
|
| 1085 |
+
np.full(sustain_samples, envelope_sustain),
|
| 1086 |
+
np.linspace(envelope_sustain, 0, release_samples, False)
|
| 1087 |
+
])
|
| 1088 |
+
|
| 1089 |
+
# 调整 envelope 长度
|
| 1090 |
+
envelope = envelope[:len(waveform)]
|
| 1091 |
+
|
| 1092 |
+
waveform *= envelope
|
| 1093 |
+
waveform *= volume
|
| 1094 |
+
|
| 1095 |
+
# 计算样本索引
|
| 1096 |
+
start_sample = int(current_time * sample_rate)
|
| 1097 |
+
end_sample = start_sample + len(waveform)
|
| 1098 |
+
if end_sample > total_samples:
|
| 1099 |
+
end_sample = total_samples
|
| 1100 |
+
waveform = waveform[:end_sample - start_sample]
|
| 1101 |
+
|
| 1102 |
+
# 合成音频
|
| 1103 |
+
audio[start_sample:end_sample] += waveform
|
| 1104 |
+
|
| 1105 |
+
# 更新当前时间
|
| 1106 |
+
current_time += duration
|
| 1107 |
+
|
| 1108 |
+
# 防止削波
|
| 1109 |
+
max_val = np.max(np.abs(audio))
|
| 1110 |
+
if max_val > 1:
|
| 1111 |
+
audio /= max_val
|
| 1112 |
+
|
| 1113 |
+
# 将音频转换为16位整数
|
| 1114 |
+
audio_int16 = np.int16(audio * 32767)
|
| 1115 |
+
|
| 1116 |
+
# 写入 WAV 文件
|
| 1117 |
+
wav_path = os.path.join(os.getcwd(), wav_filename)
|
| 1118 |
+
with wave.open(wav_path, 'w') as wav_file:
|
| 1119 |
+
n_channels = 2
|
| 1120 |
+
sampwidth = 2 # 2 bytes for int16
|
| 1121 |
+
framerate = sample_rate
|
| 1122 |
+
n_frames = len(audio_int16)
|
| 1123 |
+
comptype = "NONE"
|
| 1124 |
+
compname = "not compressed"
|
| 1125 |
+
wav_file.setparams((n_channels, sampwidth, framerate, n_frames, comptype, compname))
|
| 1126 |
+
wav_file.writeframes(audio_int16.tobytes())
|
| 1127 |
+
|
| 1128 |
+
return wav_path
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
+
class AdvancedMusicGenerator(MusicGenerator):
    """Extension point over MusicGenerator.

    Currently behaviorally identical to the base class; additional
    parameters or method overrides can be added here without touching
    MusicGenerator itself.
    """
    def __init__(self, model: nn.Module, evaluator, device: torch.device, model_path: str, optimizer=None, optimizer_path: str=None, writer: SummaryWriter=None):
        super().__init__(model, evaluator, device, model_path, optimizer, optimizer_path, writer)
        # More initialization parameters or setup can be added here.

    # Override or add methods here to further extend functionality.
|
| 1137 |
+
|
| 1138 |
+
def trainModel():
    """Train the music generation model (plain + augmented + adversarial).

    Builds datasets from MIDI files under Gbase/generateMIDI, then per epoch
    runs a standard training pass; on the final epoch it additionally runs a
    pass over augmented data followed by a GAN pass. Checkpoints are written
    only after the augmented pass (the other save blocks are disabled).
    """
    # TensorBoard logging.
    writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))

    # Tag evaluator (provides the label space).
    evaluator = MusicTagEvaluator.load(EvaluatorPath)

    # Number of distinct tags.
    num_tags = len(evaluator.all_tags)

    # Model hyperparameters.
    input_dim = 3  # pitch, duration and volume
    d_model = 512  # transformer model width
    nhead = 8  # attention heads
    num_encoder_layers = 8  # encoder depth
    dim_feedforward = 2048  # feed-forward width
    output_dim = 3  # predicted pitch, duration and volume

    # Build the model.
    model = MusicGenerationModel(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags)

    # Device selection.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"使用设备: {device}")

    # Collect MIDI training files.
    midi_directory = os.path.join(Gbase, 'generateMIDI')
    midi_files = []
    if os.path.exists(midi_directory):
        midi_files = [os.path.join(midi_directory, f) for f in os.listdir(midi_directory) if f.endswith('.mid') or f.endswith('.midi')]
        print(f"在目录 '{midi_directory}' 中找到 {len(midi_files)} 个 MIDI 文件用于训练。")
    else:
        print(f"MIDI 文件目录 '{midi_directory}' 不存在,请确保该目录存在并包含 MIDI 文件。")
        return  # nothing to train on

    # Datasets and loaders.
    max_length = 100  # sequence length per sample
    dataset_path = os.path.join(Gbase, 'mymusic.dataset')
    dataset = MidiDataset(midi_files, max_length, dataset_path, evaluator)
    datasetAug = MidiDatasetAug(midi_files, max_length, dataset_path, evaluator)
    # Sample counts for the different training passes.
    sample_size = 30000 if torch.cuda.is_available() else 15000
    sample_size1 = int(sample_size/10)
    sample_size2 = int(sample_size/300)
    total_samples = len(dataset)
    if total_samples < sample_size:
        print(f"数据集中只有 {total_samples} 个样本,无法采样 {sample_size} 个。请检查数据集。")
        return

    # Training schedule.
    epochs = 4  # adjust as needed
    learning_rate = 0.001
    batch_size= 16 if torch.cuda.is_available() else 4

    # Generator wrapper (restores checkpoints in its constructor).
    optimizer_generator = optim.AdamW(model.parameters(), lr=learning_rate * 0.1)
    generator = MusicGenerator(model, evaluator, device, model_path=ModelPath, optimizer=optimizer_generator, optimizer_path=OptimizerPath, writer=writer)

    # Discriminator for the adversarial pass.
    discriminator = Discriminator(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward).to(device)
    optimizer_discriminator = optim.AdamW(discriminator.parameters(), lr=learning_rate)
    criterion_discriminator = nn.BCELoss()

    # Restore discriminator model/optimizer checkpoints when available.
    if os.path.exists(DiscriminatorModelPath):
        discriminator.load_state_dict(torch.load(DiscriminatorModelPath, map_location=device))
        print(f"已成功载入判别器模型权重从 '{DiscriminatorModelPath}'。")
    if os.path.exists(DiscriminatorOptimizerPath):
        optimizer_discriminator.load_state_dict(torch.load(DiscriminatorOptimizerPath, map_location=device))
        print(f"已成功载入判别器优化器状态从 '{DiscriminatorOptimizerPath}'。")
    indices = list(range(total_samples))
    random_indices = random.sample(indices, sample_size)
    random_indices1 = random.sample(indices, sample_size1)
    random_indices2 = random.sample(indices, sample_size2)
    # NOTE(review): random_indicesAug is never used below — dead code?
    random_indicesAug= random.sample(indices, sample_size)
    sampler = SubsetRandomSampler(random_indices)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=2)
    sampler1 = SubsetRandomSampler(random_indices1)
    # NOTE(review): dataloaderAug reuses the full `sampler`, not a dedicated
    # one — confirm this is intentional.
    dataloaderAug = DataLoader(datasetAug, batch_size=batch_size, sampler=sampler, num_workers=2)
    dataloader1 = DataLoader(datasetAug, batch_size=8, sampler=sampler1, num_workers=2)
    sampler2 = SubsetRandomSampler(random_indices2)
    dataloader2 = DataLoader(dataset, batch_size=batch_size, sampler=sampler2, num_workers=2)
    # NOTE(review): sampler3 is created but dataloader3 uses sampler2, and
    # dataloader2/dataloader3 are never consumed below.
    sampler3 = SubsetRandomSampler(random_indices2)
    dataloader3 = DataLoader(datasetAug, batch_size=batch_size, sampler=sampler2, num_workers=2)
    # Main training loop.
    print("開始訓練...")
    for epoch in range(1, epochs + 1):
        try:
            avg_loss = generator.train_epoch(
                dataloader,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )

            # Checkpointing after the plain pass is currently disabled
            # (dead string literal, kept verbatim).
            """
            generator.save_model(epoch, avg_loss)
            torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            print(f"判别器模型和优化器已保存至 '{DiscriminatorModelPath}' 和 '{DiscriminatorOptimizerPath}'。")

            """

        except KeyboardInterrupt:
            print("训练过程被手动中断。")
            break
        except Exception as e:
            print(f"在训练 epoch {epoch} 时发生错误: {e}")

        # Only run the reinforcement + adversarial passes on the last epoch.
        if epoch!=4:continue
        print("開始強化訓練...")
        try:
            avg_loss = generator.train_epoch(
                dataloaderAug,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )

            # Persist generator and discriminator after the augmented pass.
            #"""
            generator.save_model(epoch, avg_loss)
            torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            print(f"判别器模型和优化器已保存至 '{DiscriminatorModelPath}' 和 '{DiscriminatorOptimizerPath}'。")
            # Saving the evaluator is disabled.
            #evaluator.save(EvaluatorPath)
            #"""
        except KeyboardInterrupt:
            print("训练过程被手动中断。")
            break
        except Exception as e:
            print(f"在训练 epoch {epoch} 时发生错误: {e}")
        print("開始對抗訓練...")
        try:
            avg_loss = generator.train_epoch_gan(
                dataloader1,
                optimizer_generator,
                optimizer_discriminator,
                nn.MSELoss(),
                nn.BCELoss(),
                criterion_discriminator,
                discriminator,
                epoch
            )
            """
            generator.save_model(epoch, avg_loss)
            # 保存判别器模型和优化器
            torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            print(f"判别器模型和优化器已保存至 '{DiscriminatorModelPath}' 和 '{DiscriminatorOptimizerPath}'。")
            # 保存评估器
            #evaluator.save(EvaluatorPath)
            #"""
        except KeyboardInterrupt:
            print("训练过程被手动中断。")
            break
        except Exception as e:
            print(f"在训练 epoch {epoch} 时发生错误: {e}")
            continue  # continue with the next epoch

    # Close the TensorBoard writer.
    writer.close()
|
| 1308 |
+
|
| 1309 |
+
def loadMusicGenerator():
    """Build and return the music generator together with its tag evaluator.

    Reconstructs the model with the same hyper-parameters used at training
    time, moves it to the best available device, and wraps it in an
    AdvancedMusicGenerator that restores weights from ``ModelPath``.

    Returns:
        tuple: ``(generator, evaluator)`` — the AdvancedMusicGenerator and
        the MusicTagEvaluator it was built with.
    """
    # TensorBoard writer handed to the generator for its own logging.
    writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))

    # Tag evaluator; the serialized state at EvaluatorPath is deliberately
    # not loaded here (mirrors the original commented-out `.load(...)`).
    evaluator = MusicTagEvaluator()
    num_tags = len(evaluator.all_tags)

    # Architecture hyper-parameters — these must match the checkpoint that
    # AdvancedMusicGenerator loads, or the state_dict will not fit.
    input_dim = 3           # pitch, duration and volume
    output_dim = 3
    d_model = 512
    nhead = 8
    num_encoder_layers = 8
    dim_feedforward = 2048

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    model = MusicGenerationModel(
        input_dim,
        d_model,
        nhead,
        num_encoder_layers,
        dim_feedforward,
        output_dim,
        num_tags,
    ).to(device)

    generator = AdvancedMusicGenerator(
        model, evaluator, device, model_path=ModelPath, writer=writer
    )
    return generator, evaluator
|
| 1339 |
+
|
| 1340 |
+
|
| 1341 |
+
# Module-level singletons: build the generator/evaluator pair once at import
# time so the Gradio request handlers can reuse the loaded model.
MyMusicGenerator, MyMusicTagEvaluator = loadMusicGenerator()
|
| 1342 |
+
|
| 1343 |
+
import gradio as gr
|
| 1344 |
+
import numpy as np
|
| 1345 |
+
import time
|
| 1346 |
+
import os
|
| 1347 |
+
|
| 1348 |
+
# Assuming your existing functions and setup are defined above
|
| 1349 |
+
|
| 1350 |
+
def generate_music(*tags, use_random=False):
    """Generate a music clip for the selected tags and return it as WAV.

    Args:
        *tags: One dropdown value per category in ``MUSIC_TAGS`` (same
            order as ``MUSIC_TAGS.keys()``). When called from Gradio, the
            'Use Random Tags' checkbox value arrives as a trailing bool in
            this tuple — see the note below.
        use_random: When True, ignore ``tags`` and sample a random tag set.

    Returns:
        tuple: ``(wav_file_path, tags_dict)`` — path to the rendered WAV
        and the tag dictionary actually used for generation.

    BUG FIX: Gradio passes *all* ``inputs`` positionally, so the keyword-only
    ``use_random`` parameter was never set by the checkbox; its boolean value
    landed in ``tags`` instead (and was silently dropped by ``zip``), making
    "Use Random Tags" a no-op. Peel a trailing bool off ``tags`` so the
    checkbox works; direct keyword callers are unaffected.
    """
    if tags and isinstance(tags[-1], bool):
        # Checkbox value delivered positionally by Gradio.
        use_random = tags[-1] or use_random
        tags = tags[:-1]

    if use_random:
        tags_dict = randomMusicTags()
    else:
        # Dropdown order matches MUSIC_TAGS.keys() (see UI construction).
        tags_dict = dict(zip(MUSIC_TAGS.keys(), tags))

    # Generate a music21 stream conditioned on the tags; temperature is
    # jittered per call for variety.
    generated_stream = MyMusicGenerator.generate_music(
        tag_conditions=tags_dict,
        max_length=130,
        temperature=np.random.uniform(0.7, 1.1),
    )

    # Persist the stream as MIDI, then render it to WAV for the audio widget.
    midi_filename = f"music_{int(time.time())}.mid"
    mid_path = os.path.join(Gbase, midi_filename)
    generated_stream.write('midi', fp=mid_path)

    wav_file = MyMusicGenerator.custom_midi_to_wav(
        generated_stream, os.path.join(Gbase, f"{midi_filename[:-4]}.wav")
    )

    return wav_file, tags_dict
|
| 1369 |
+
|
| 1370 |
+
# Define the interface
|
| 1371 |
+
# ---- Gradio UI wiring --------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Music Generation with Tags")

    with gr.Row():
        with gr.Column():
            # One dropdown per tag category, pre-selecting the first option
            # of each; order follows MUSIC_TAGS.keys().
            tag_inputs = []
            for category in MUSIC_TAGS.keys():
                dropdown = gr.Dropdown(
                    value=MUSIC_TAGS[category][0],
                    choices=MUSIC_TAGS[category],
                    label=category.capitalize(),
                )
                tag_inputs.append(dropdown)
        with gr.Column():
            use_random = gr.Checkbox(label="Use Random Tags")
            generate_btn = gr.Button("Generate Music")
            output_audio = gr.Audio(label="Generated Music")
            output_tags = gr.JSON(label="Generated Tags")

    # The dropdown list is passed directly; the checkbox is appended as the
    # final input component.
    generate_btn.click(
        fn=generate_music,
        inputs=tag_inputs + [use_random],
        outputs=[output_audio, output_tags],
    )

# Launch the web interface.
demo.launch()
|