Acelle Krislette Rosales committed
Commit fc7b4a9 · 1 Parent(s): fa617cf
Initial commit: Added application code
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +40 -0
- Dockerfile +50 -0
- app/__init__.py +0 -0
- app/schemas.py +47 -0
- app/server.py +202 -0
- app/utils.py +16 -0
- config/data_config.yml +8 -0
- config/model_config.yml +11 -0
- config/server_config.yml +25 -0
- models/llm2vec/.gitkeep +0 -0
- models/spectttra/.gitkeep +0 -0
- poetry.lock +0 -0
- pyproject.toml +47 -0
- scripts/evaluate.py +164 -0
- scripts/explain.py +74 -0
- scripts/explain_test.py +75 -0
- scripts/explain_with_json.py +97 -0
- scripts/predict.py +82 -0
- scripts/train.py +160 -0
- src/__init__.py +0 -0
- src/features/__init__.py +0 -0
- src/features/llm2vec.py +0 -0
- src/features/spectttra.py +0 -0
- src/llm2vectrain/__init__.py +0 -0
- src/llm2vectrain/__pycache__/__init__.cpython-312.pyc +0 -0
- src/llm2vectrain/__pycache__/access_token.cpython-312.pyc +0 -0
- src/llm2vectrain/__pycache__/llm2vec_trainer.cpython-312.pyc +0 -0
- src/llm2vectrain/__pycache__/model.cpython-312.pyc +0 -0
- src/llm2vectrain/config.py +5 -0
- src/llm2vectrain/llm2vec_trainer.py +159 -0
- src/llm2vectrain/model.py +51 -0
- src/models/__init__.py +0 -0
- src/models/__pycache__/__init__.cpython-312.pyc +0 -0
- src/models/__pycache__/mlp.cpython-312.pyc +0 -0
- src/models/fusion.py +0 -0
- src/models/mlp.py +753 -0
- src/musiclime/__init__.py +0 -0
- src/musiclime/__pycache__/__init__.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/__init__.cpython-313.pyc +0 -0
- src/musiclime/__pycache__/explainer.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/explainer.cpython-313.pyc +0 -0
- src/musiclime/__pycache__/factorization.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/musiclime.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/musiclime_wrapper.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/optimized_wrapper.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/print_utils.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/text_utils.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/true_musiclime.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/utils.cpython-312.pyc +0 -0
- src/musiclime/__pycache__/wrapper.cpython-312.pyc +0 -0
.gitignore
ADDED
@@ -0,0 +1,40 @@
# Python
__pycache__/
*.py[cod]
*.pyo
*.pyd
.Python
*.so
.pytest_cache/
.coverage

# Env
.env

# Virtual environments
.venv/
.env.local

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Logs
*.log
logs/

# Model cache (let HF download fresh)
.cache/
models/.cache/

# Development files
.pytest_cache/
notebooks/
tests/
docs/
Dockerfile
ADDED
@@ -0,0 +1,50 @@
# Use CUDA base for GPU support
FROM nvidia/cuda:13.0.1-runtime-ubuntu22.04

# Set timezone non-interactively
ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=UTC

# Install Python and basic dependencies
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update && apt-get install -y \
    python3.11 \
    python3.11-dev \
    python3.11-venv \
    python3.11-distutils \
    git \
    libsndfile1 \
    ffmpeg \
    curl \
    && rm -rf /var/lib/apt/lists/* \
    && ln -sf /usr/bin/python3.11 /usr/bin/python3 \
    && ln -sf /usr/bin/python3.11 /usr/bin/python \
    && curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11

WORKDIR /app

# Copy and install Python dependencies
COPY pyproject.toml poetry.lock* ./
RUN python3.11 -m pip install poetry && \
    poetry config virtualenvs.create false && \
    poetry install --only=main

# Copy application code
COPY src/ ./src/
COPY app/ ./app/
COPY config/ ./config/
COPY models/ ./models/
COPY scripts/ ./scripts/
COPY .env ./

# Set environment
ENV PYTHONPATH="/app"
ENV HF_HOME="/app/.cache/huggingface"

# Hugging Face Spaces specific, expose port 7860
EXPOSE 7860

# Run on port 7860 for HF Spaces
CMD ["uvicorn", "app.server:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py
ADDED
File without changes
app/schemas.py
ADDED
@@ -0,0 +1,47 @@
from pydantic import BaseModel
from typing import Dict, List, Optional


# Pydantic model for the base response
class BaseResponse(BaseModel):
    status: str
    message: Optional[str] = None


class WelcomeResponse(BaseResponse):
    endpoints: Dict[str, str]


class ModelInfoResponse(BaseResponse):
    model_name: str
    model_version: str
    supported_formats: List[str]
    max_file_size_mb: int
    training_info: Optional[Dict] = None
    last_updated: Optional[str] = None


# Pydantic model for the prediction response
class PredictionResponse(BaseModel):
    status: str
    lyrics: str
    audio_file_name: str
    audio_content_type: str
    audio_file_size: int
    results: Optional[Dict] = None


class PredictionXAIResponse(BaseModel):
    status: str
    lyrics: str
    audio_file_name: str
    audio_content_type: str
    audio_file_size: int
    results: Optional[Dict] = None


# Pydantic model for the error response
class ErrorResponse(BaseModel):
    status: str = "error"
    code: int
    message: str
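Not part of the commit: a minimal construction sketch for one of the response models above, with placeholder values, showing the shape the /api/v1/predict handler returns.

from app.schemas import PredictionResponse

resp = PredictionResponse(
    status="success",
    lyrics="Some lyrics text here",
    audio_file_name="song.mp3",       # placeholder file name
    audio_content_type="audio/mpeg",
    audio_file_size=1048576,          # placeholder size in bytes
    results={"prediction": "Human-Composed", "probability": 0.82, "label": 1},
)
print(resp.model_dump())  # use .dict() instead on Pydantic v1; poetry.lock pins the exact version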
app/server.py
ADDED
@@ -0,0 +1,202 @@
# Fast API imports
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware

# Processing imports
import librosa
import io

# Utils/schemas imports
from app.schemas import (
    ErrorResponse,
    ModelInfoResponse,
    PredictionResponse,
    PredictionXAIResponse,
    WelcomeResponse,
)
from app.utils import load_config

# Model/XAI-related imports
from scripts.explain import musiclime
from scripts.predict import predict_pipeline


# Load config at startup
config = load_config()

# Extract configuration values
MAX_FILE_SIZE = config["file_upload"]["max_file_size_mb"] * 1024 * 1024
MAX_LYRICS_LENGTH = config["file_upload"]["max_lyrics_length"]
ALLOWED_AUDIO_TYPES = config["file_upload"]["allowed_audio_types"]

# Initialize fast API app with extracted config values
app = FastAPI(title=config["server"]["title"], version=config["server"]["version"])

# Initialize CORS with config values
cors_config = config["api"]["cors"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_config["allow_origins"],
    allow_credentials=cors_config["allow_credentials"],
    allow_methods=cors_config["allow_methods"],
    allow_headers=cors_config["allow_headers"],
)


async def validate_audio_file(audio_file: UploadFile = File(...)):
    """Validate audio file type and size."""
    # Check file size
    audio_content = await audio_file.read()
    if len(audio_content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB.",
        )

    # Check file type
    if audio_file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Supported formats: {', '.join(ALLOWED_AUDIO_TYPES)}",
        )

    # Reset file pointer for later use
    audio_file.file.seek(0)
    return audio_file, audio_content


def validate_lyrics(lyrics: str = Form(...)):
    """Validate lyrics length and content."""
    if len(lyrics) > MAX_LYRICS_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Lyrics too long. Maximum length is {MAX_LYRICS_LENGTH} characters.",
        )

    # Basic sanitization, remove excessive whitespace
    lyrics = lyrics.strip()
    if not lyrics:
        raise HTTPException(
            status_code=400,
            detail="Lyrics cannot be empty.",
        )

    return lyrics


@app.get("/", response_model=WelcomeResponse, tags=["Root"])
def root():
    """
    Root endpoint to check if the API is running.
    """
    return WelcomeResponse(
        status="success",
        message="Welcome to Bach or Bot API!",
        endpoints={
            "/": "This welcome message",
            "/docs": "FastAPI auto-generated API docs",
            "/api/v1/model/info": "Model information and capabilities",
            "/api/v1/predict": "POST endpoint for bach-or-bot prediction",
            "/api/v1/explain": "POST endpoint for prediction with explainability",
        },
    )


@app.post(
    "/api/v1/predict",
    response_model=PredictionResponse,
    responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
)
async def predict_music(
    lyrics: str = Depends(validate_lyrics), audio_file_data=Depends(validate_audio_file)
):
    """
    Endpoint to predict whether a music sample is human-composed or AI-generated.
    """
    try:
        # Get the audio file and content from sanitized and cleaned audio file
        audio_file, audio_content = audio_file_data

        # Load audio from uploaded file with error handling for corrupted files
        try:
            audio_data, sr = librosa.load(io.BytesIO(audio_content))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")

        # Call MLP predict runner script to get results
        results = predict_pipeline(audio_data, lyrics)

        return PredictionResponse(
            status="success",
            lyrics=lyrics,
            audio_file_name=audio_file.filename,
            audio_content_type=audio_file.content_type,
            audio_file_size=len(audio_content),
            results=results,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post(
    "/api/v1/explain",
    response_model=PredictionXAIResponse,
    responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
)
async def predict_music_with_xai(
    lyrics: str = Depends(validate_lyrics), audio_file_data=Depends(validate_audio_file)
):
    """
    Endpoint to predict whether a music sample is human-composed or AI-generated with explainability.
    """
    try:
        # Get the audio file and content from sanitized and cleaned audio file
        audio_file, audio_content = audio_file_data

        # Load audio from uploaded file with error handling for corrupted files
        try:
            audio_data, sr = librosa.load(io.BytesIO(audio_content))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")

        # Call musiclime runner script to get results
        results = musiclime(audio_data, lyrics)

        return PredictionXAIResponse(
            status="success",
            lyrics=lyrics,
            audio_file_name=audio_file.filename,
            audio_content_type=audio_file.content_type,
            audio_file_size=len(audio_content),
            results=results,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/v1/model/info", response_model=ModelInfoResponse, tags=["Model"])
async def get_model_info():
    """
    Get information about the current model and its capabilities.
    """
    try:
        # Get supported formats from config
        supported_formats = [fmt.replace("audio/", "") for fmt in ALLOWED_AUDIO_TYPES]

        return ModelInfoResponse(
            status="success",
            message="Model information retrieved successfully",
            model_name="Bach or Bot",
            model_version="1.0.0",  # TODO: Load from model metadata when available
            supported_formats=supported_formats,
            max_file_size_mb=config["file_upload"]["max_file_size_mb"],
            training_info={
                "dataset": "Human-Composed and AI-generated music samples",
                "architecture": "To be specified",  # TODO: Update when model is implemented
                "accuracy": "To be determined",  # TODO: Update with actual metrics
            },
            last_updated="2024-01-01T00:00:00Z",  # TODO: Update with actual timestamp
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
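Not part of the commit: a minimal client sketch for the /api/v1/predict endpoint above, assuming the container built from the Dockerfile is running locally on port 7860; the audio file name and lyrics are placeholders. The form field names match the dependency parameters (lyrics, audio_file).

import requests

with open("song.mp3", "rb") as f:  # placeholder audio file
    response = requests.post(
        "http://localhost:7860/api/v1/predict",
        data={"lyrics": "Some lyrics text here"},
        files={"audio_file": ("song.mp3", f, "audio/mpeg")},
    )

response.raise_for_status()
print(response.json())  # status, lyrics, audio file metadata, and the results dict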
app/utils.py
ADDED
@@ -0,0 +1,16 @@
from pathlib import Path
import yaml


def load_config():
    """
    Load server configs from YAML file.
    """
    # Define path first
    config_path = Path(__file__).parent.parent / "config" / "server_config.yml"

    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path, "r") as file:
        return yaml.safe_load(file)
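Not part of the commit: a short usage sketch for load_config(); the keys mirror config/server_config.yml further down in this diff.

from app.utils import load_config

config = load_config()
print(config["server"]["title"])                  # "Bach or Bot API"
print(config["file_upload"]["max_file_size_mb"])  # 10
print(config["api"]["cors"]["allow_origins"])     # ["*"]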
config/data_config.yml
ADDED
@@ -0,0 +1,8 @@
base_dir: "."

paths:
  dataset_npz: "data/processed/training_data.npz"
  dataset_csv: "data/external/songs_dataset.csv"
  raw_dir: "data/raw"
  processed_dir: "data/processed"
  pca_path: "data/processed/pca_model.pkl"
config/model_config.yml
ADDED
@@ -0,0 +1,11 @@
mlp:
  hidden_layers: [1024, 512, 256, 128, 64, 32]  # 6 hidden layers
  dropout: [0.4, 0.3, 0.5, 0.5, 0.5]  # Dropout rates for each layer
  learning_rate: 0.0001  # Adam optimizer
  batch_size: 128  # Number of samples processed together
  epochs: 200  # Maximum training iterations
  patience: 5  # Early stopping patience

  weight_decay: 0.1  # L2 regularization
  gradient_clipping: 0.5  # Prevent exploding gradients
  mixup_alpha: 0.2  # For data augmentation during training, 0 disables MixUp
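Not part of the commit: a sketch of reading these hyperparameters with PyYAML. The repository's own load_config in src/models/mlp.py is not shown in this diff, so this stand-in only illustrates the expected structure, not its actual implementation.

import yaml

with open("config/model_config.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["mlp"]["hidden_layers"])  # [1024, 512, 256, 128, 64, 32]
print(cfg["mlp"]["learning_rate"])  # 0.0001
print(cfg["mlp"]["mixup_alpha"])    # 0.2 (0 disables MixUp)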
config/server_config.yml
ADDED
@@ -0,0 +1,25 @@
# Server Configuration
server:
  title: "Bach or Bot API"
  version: "1.0.0"

# File upload limits and validation
file_upload:
  # Maximum file size in MB
  max_file_size_mb: 10
  # Maximum characters for lyrics
  max_lyrics_length: 10000
  allowed_audio_types:
    - "audio/wav"
    - "audio/mpeg"
    - "audio/mp3"
    - "application/octet-stream"

# API Configuration
api:
  cors:
    # TODO: Change to specific origins in production
    allow_origins: ["*"]
    allow_credentials: true
    allow_methods: ["*"]
    allow_headers: ["*"]
models/llm2vec/.gitkeep
ADDED
File without changes
models/spectttra/.gitkeep
ADDED
File without changes
poetry.lock
ADDED
The diff for this file is too large to render. See raw diff.
pyproject.toml
ADDED
@@ -0,0 +1,47 @@
[project]
name = "bach-or-bot"
version = "0.1.0"
description = "A binary classifier to distinguish between Human-composed and AI-generated music"
authors = [
    {name = "Acelle Krislette Rosales", email = "acellekrislette@gmail.com"},
    {name = "Hans Christian Queja", email = "hansqueja8@gmail.com"},
    {name = "Regina Bonfiacio", email = "bonifacioregina06@gmail.com"},
    {name = "Sean Matthew Sinalubong", email = "s3amatth3wsinalubong@gmail.com"},
    {name = "Syruz Ken Domingo", email = "syruzkenc.domingo@gmail.com"},
]
license = {text = "MIT"}
readme = "README.md"
requires-python = ">=3.11,<3.14"
dependencies = [
    "librosa (>=0.11.0,<0.12.0)",
    "pandas (>=2.3.2,<3.0.0)",
    "soundfile (>=0.13.1,<0.14.0)",
    "torchaudio (>=2.8.0,<3.0.0)",
    "transformers (==4.44.2)",
    "llm2vec (>=0.2.3,<0.3.0)",
    "peft (>=0.17.1,<0.18.0)",
    "timm (>=1.0.19,<2.0.0)",
    "pyyaml (>=6.0.2,<7.0.0)",
    "tqdm (>=4.67.1,<5.0.0)",
    "torch (>=2.8.0,<3.0.0)",
    "openunmix (>=1.3.0,<2.0.0)",
    "fastapi (>=0.117.1,<0.118.0)",
    "uvicorn (>=0.36.0,<0.37.0)",
    "scikit-learn (>=1.5.2)",
    "torchao (>=0.13.0,<0.14.0)",
    "lime (>=0.2.0.1,<0.3.0.0)",
    "hf-xet (>=1.1.10,<2.0.0)",
    "huggingface-hub[cli] (>=0.35.3,<0.36.0)",
    "pytest (>=8.4.2,<9.0.0)",
    "python-multipart (>=0.0.20,<0.0.21)",
    "python-dotenv (>=1.1.1,<2.0.0)"
]


[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"


[tool.poetry]
package-mode = false
scripts/evaluate.py
ADDED
@@ -0,0 +1,164 @@
"""
MLP Model Evaluation Script for AI vs Human Music Detection
==========================================================

This script evaluates the performance of the trained MLP classifier on test data.
It gives a complete performance report showing how well the model can distinguish
between AI-generated and human-composed music.

What this script does:
- Loads our saved/trained MLP model
- Tests it on held-out test data (music the model has never seen)
- Calculates accuracy, precision, recall, and F1-score
- Reports confusion statistics (true positives, true negatives, false positives, false negatives)
- Displays sample predictions with probabilities for transparency

Quick Start:
---------------------------
# Basic evaluation with default model path
python evaluate.py

# Evaluate a specific model
python evaluate.py --model "models/fusion/mlp_multimodal.pth"

# From code
from evaluate import evaluate_model
results = evaluate_model("models/fusion/mlp_multimodal.pth")

Performance Metrics Explained:
------------------------------
- Accuracy: Overall correctness (how many songs classified correctly)
- Precision: Of songs predicted as human, how many actually were human
- Recall: Of all human songs, how many did we correctly identify
- F1-Score: Balance between precision and recall (harmonic mean)
- Confusion stats:
  TP = Human songs correctly identified
  TN = AI songs correctly identified
  FP = AI songs incorrectly labeled as human
  FN = Human songs incorrectly labeled as AI

Expected Output:
----------------
Loading model from: models/fusion/mlp_multimodal.pth
Loaded dataset: (50000, 684), Labels: 50000
Test set size: (10000, 684)
Evaluating model on test set...

Sample predictions:
True: 1, Pred: 1, Prob: 0.8234  # Correctly identified human song
True: 0, Pred: 0, Prob: 0.1456  # Correctly identified AI song
True: 1, Pred: 0, Prob: 0.4123  # Missed a human song (false negative)

=== Evaluation Results ===
Test Accuracy: 87.54%
Test Loss: 0.3412
Precision: 0.8832
Recall: 0.8654
F1-Score: 0.8742
"""

import argparse
import logging
import numpy as np
from pathlib import Path

from src.models.mlp import build_mlp, load_config
from src.utils.config_loader import DATASET_NPZ
from sklearn.model_selection import train_test_split

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def evaluate_model(model_path: str = "models/fusion/mlp_multimodal.pth"):
    logger.info(f"Loading model from: {model_path}")

    # Check if dataset exists
    if not Path(DATASET_NPZ).exists():
        raise FileNotFoundError(f"Dataset not found at {DATASET_NPZ}. Run train.py first.")

    # Load the full dataset
    loaded_data = np.load(DATASET_NPZ)
    X = loaded_data["X"]
    Y = loaded_data["Y"]

    logger.info(f"Loaded dataset: {X.shape}, Labels: {len(Y)}")

    # Split data (same as training)
    from src.utils.dataset import dataset_scaler
    data = dataset_scaler(X, Y)
    X_test, y_test = data["test"]

    logger.info(f"Test set size: {X_test.shape}")

    # Load configuration
    config = load_config("config/model_config.yml")

    # Build model architecture (needed for loading weights)
    mlp_classifier = build_mlp(input_dim=X_test.shape[1], config=config)

    # Load trained model
    mlp_classifier.load_model(model_path)

    # Evaluate on test set
    logger.info("Evaluating model on test set...")
    test_results = mlp_classifier.evaluate(X_test, y_test)

    # Get predictions for detailed analysis
    probabilities, predictions = mlp_classifier.predict(X_test)

    # Show a few sample predictions
    for i in range(10):
        print(f"True: {y_test[i]}, Pred: {predictions[i]}, Prob: {probabilities[i]:.4f} "
              f"(Probability of predicted class)")

    logger.info("=== Evaluation Results ===")
    logger.info(f"Test Accuracy: {test_results['test_accuracy']:.2f}%")
    logger.info(f"Test Loss: {test_results['test_loss']:.4f}")

    # Additional statistics
    true_positives = np.sum((y_test == 1) & (predictions == 1))
    true_negatives = np.sum((y_test == 0) & (predictions == 0))
    false_positives = np.sum((y_test == 0) & (predictions == 1))
    false_negatives = np.sum((y_test == 1) & (predictions == 0))

    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    logger.info(f"Precision: {precision:.4f}")
    logger.info(f"Recall: {recall:.4f}")
    logger.info(f"F1-Score: {f1_score:.4f}")

    # Include all metrics in return dict
    return {
        "test_accuracy": test_results["test_accuracy"],
        "test_loss": test_results["test_loss"],
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "true_positives": int(true_positives),
        "true_negatives": int(true_negatives),
        "false_positives": int(false_positives),
        "false_negatives": int(false_negatives)
    }


def main():
    """Main evaluation function."""
    parser = argparse.ArgumentParser(description='Evaluate Bach-or-Bot MLP classifier')
    parser.add_argument('--model', default='models/fusion/mlp_multimodal.pth',
                        help='Path to trained model')
    args = parser.parse_args()

    try:
        results = evaluate_model(args.model)
        logger.info("Evaluation completed successfully!")
    except Exception as e:
        logger.error(f"Evaluation failed: {str(e)}")
        raise


if __name__ == "__main__":
    main()
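Not part of the commit: a quick arithmetic check of the F1 figure in the expected output above, since F1 is the harmonic mean of precision and recall.

# Verify the sample report: F1 = 2PR / (P + R)
precision, recall = 0.8832, 0.8654
f1 = 2 * precision * recall / (precision + recall)
print(f"{f1:.4f}")  # 0.8742, matching the sample report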
scripts/explain.py
ADDED
@@ -0,0 +1,74 @@
import numpy as np
from datetime import datetime
from src.musiclime.explainer import MusicLIMEExplainer
from src.musiclime.wrapper import MusicLIMEPredictor


def musiclime(audio_data, lyrics_text):
    """
    MusicLIME wrapper for API usage.

    Args:
        audio_data: Audio array (from librosa.load or similar)
        lyrics_text: String containing lyrics

    Returns:
        dict: Structured explanation results
    """
    start_time = datetime.now()

    # Create musiclime instances
    explainer = MusicLIMEExplainer()
    predictor = MusicLIMEPredictor()

    # Generate explanations
    explanation = explainer.explain_instance(
        audio=audio_data,
        lyrics=lyrics_text,
        predict_fn=predictor,
        num_samples=1000,
        labels=(1,),
    )

    # Get prediction info
    original_prediction = explanation.predictions[0]
    predicted_class = np.argmax(original_prediction)
    confidence = float(np.max(original_prediction))

    # Get top 10 features
    top_features = explanation.get_explanation(label=1, num_features=10)

    # Calculate runtime
    end_time = datetime.now()
    runtime_seconds = (end_time - start_time).total_seconds()

    return {
        "prediction": {
            "class": int(predicted_class),
            "class_name": "Human-Composed" if predicted_class == 1 else "AI-Generated",
            "confidence": confidence,
            "probabilities": original_prediction.tolist(),
        },
        "explanations": [
            {
                "rank": i + 1,
                "modality": item["type"],
                "feature_text": item["feature"],
                "weight": float(item["weight"]),
                "importance": abs(float(item["weight"])),
            }
            for i, item in enumerate(top_features)
        ],
        "summary": {
            "total_features_analyzed": len(top_features),
            "audio_features_count": len(
                [f for f in top_features if f["type"] == "audio"]
            ),
            "lyrics_features_count": len(
                [f for f in top_features if f["type"] == "lyrics"]
            ),
            "runtime_seconds": runtime_seconds,
            "samples_generated": 1000,
            "timestamp": start_time.isoformat(),
        },
    }
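Not part of the commit: a usage sketch for the musiclime() wrapper above, mirroring how app/server.py and scripts/explain_test.py drive it; the sample paths are the ones used in explain_test.py.

import librosa

from scripts.explain import musiclime

audio_data, sr = librosa.load("data/external/sample_2.mp3")
lyrics_text = open("data/external/sample_2.txt", encoding="utf-8").read()

result = musiclime(audio_data, lyrics_text)
print(result["prediction"]["class_name"], result["prediction"]["confidence"])
for item in result["explanations"][:3]:
    print(item["rank"], item["modality"], item["feature_text"], item["weight"])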
scripts/explain_test.py
ADDED
@@ -0,0 +1,75 @@
from datetime import datetime
import librosa
import numpy as np

from pathlib import Path
from src.musiclime.explainer import MusicLIMEExplainer
from src.musiclime.wrapper import MusicLIMEPredictor
from src.musiclime.print_utils import green_bold


def explain():
    # Start timing and time stamp to record how long the entire explanation thingy is
    start_time = datetime.now()
    print(
        green_bold(
            f"[MusicLIME] Started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}"
        )
    )

    # Create musiclime-related instances
    explainer = MusicLIMEExplainer()
    predictor = MusicLIMEPredictor()

    # Set the path for audio and lyrics [these are samples only - song is Silver Spring]
    audio_path = Path("data/external/sample_2.mp3")
    lyrics_path = Path("data/external/sample_2.txt")

    # Load the audio as an object + load the lyrics as string
    y, sr = librosa.load(audio_path)
    lyrics_text = lyrics_path.read_text(encoding="utf-8")

    # Generate explanations using musiclime
    explanation = explainer.explain_instance(
        audio=y,
        lyrics=lyrics_text,
        predict_fn=predictor,
        num_samples=1000,
        labels=(1,),
    )

    # Get original prediction (first sample is always the orig meaning unperturbed)
    original_prediction = explanation.predictions[0]
    predicted_class = np.argmax(original_prediction)

    # Print explanations
    results = explanation.get_explanation(label=1, num_features=10)
    print("\n" + "=" * 80)
    print(
        f"[MusicLIME] Top 10 most important features for {"Human-Composed" if predicted_class == 1 else "AI-Generated"} prediction"
    )
    print("=" * 80)

    for i, item in enumerate(results, 1):
        print(
            f"#{i:2d} | {item['type']:6s} | {item['feature'][:50]:50s} | weight: {item['weight']:+.6f}"
        )

    print("=" * 80)
    print(f"[MusicLIME] Total features analyzed: {len(results)}")
    print("[MusicLIME] Higher absolute weights = more important for the prediction")

    # End timing and timestamp
    end_time = datetime.now()
    total_duration = end_time - start_time
    total_minutes = total_duration.total_seconds() / 60
    print(f"\n[MusicLIME] Finished at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(
        green_bold(
            f"[MusicLIME] Total execution time: {total_minutes:.2f} minutes ({total_duration.total_seconds():.1f} seconds)"
        )
    )


if __name__ == "__main__":
    explain()
scripts/explain_with_json.py
ADDED
@@ -0,0 +1,97 @@
from datetime import datetime
import librosa
import numpy as np

from pathlib import Path
from src.musiclime.explainer import MusicLIMEExplainer
from src.musiclime.wrapper import MusicLIMEPredictor
from src.musiclime.print_utils import green_bold


def explain():
    # Start timing and time stamp to record how long the entire explanation thingy is
    start_time = datetime.now()
    print(
        green_bold(
            f"[MusicLIME] Started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}"
        )
    )

    # Create musiclime-related instances
    explainer = MusicLIMEExplainer()
    predictor = MusicLIMEPredictor()

    # Set the path for audio and lyrics [these are samples only - song is Silver Spring]
    audio_path = Path("data/external/sample_2.mp3")
    lyrics_path = Path("data/external/sample_2.txt")

    # Load the audio as an object + load the lyrics as string
    y, sr = librosa.load(audio_path)
    lyrics_text = lyrics_path.read_text(encoding="utf-8")

    # Generate explanations using musiclime
    explanation = explainer.explain_instance(
        audio=y,
        lyrics=lyrics_text,
        predict_fn=predictor,
        num_samples=1000,
        labels=(1,),
    )

    # Get original prediction (first sample is always the orig meaning unperturbed)
    original_prediction = explanation.predictions[0]
    predicted_class = np.argmax(original_prediction)
    confidence = original_prediction[predicted_class]

    # Create song info from the prediction
    song_info = {
        "filename": "sample.mp3",
        "duration": f"{len(y)/44100:.1f}s",
        "original_prediction": {
            "class": "Human-Composed" if predicted_class == 1 else "AI-Generated",
            "confidence": float(confidence),
            "raw_probabilities": {
                "AI": float(original_prediction[0]),
                "Human": float(original_prediction[1]),
            },
        },
    }

    # Save with prediction data
    explanation.save_to_json(
        filepath=f"musiclime_explanation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
        song_info=song_info,
        num_features=10,
    )

    # Print explanations
    results = explanation.get_explanation(label=1, num_features=10)
    print("\n" + "=" * 80)
    print(
        f"[MusicLIME] Top 10 most important features for {"Human-Composed" if predicted_class == 1 else "AI-Generated"} prediction"
    )
    print("=" * 80)

    for i, item in enumerate(results, 1):
        print(
            f"#{i:2d} | {item['type']:6s} | {item['feature'][:50]:50s} | weight: {item['weight']:+.6f}"
        )

    print("=" * 80)
    print(f"[MusicLIME] Total features analyzed: {len(results)}")
    print("[MusicLIME] Higher absolute weights = more important for the prediction")

    # End timing and timestamp
    end_time = datetime.now()
    total_duration = end_time - start_time
    total_minutes = total_duration.total_seconds() / 60
    print(f"\n[MusicLIME] Finished at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(
        green_bold(
            f"[MusicLIME] Total execution time: {total_minutes:.2f} minutes ({total_duration.total_seconds():.1f} seconds)"
        )
    )


if __name__ == "__main__":
    explain()
scripts/predict.py
ADDED
@@ -0,0 +1,82 @@
from src.preprocessing.preprocessor import single_preprocessing
from src.spectttra.spectttra_trainer import spectttra_predict
from src.llm2vectrain.model import load_llm2vec_model
from src.llm2vectrain.llm2vec_trainer import l2vec_single_train, load_pca_model
from src.models.mlp import build_mlp, load_config
from pathlib import Path
from src.utils.config_loader import DATASET_NPZ
from src.utils.dataset import instance_scaler

from pathlib import Path
import numpy as np
import torch


def predict_pipeline(audio, lyrics: str):
    """
    Predict script which includes preprocessing, feature extraction, and
    training the MLP model for a single data sample.

    Parameters
    ----------
    audio : audio_object
        Audio object file

    lyric : string
        Lyric string

    Returns
    -------
    prediction : str
        A string result of the prediction

    label : int
        A numerical representation of the prediction
    """

    # Instantiate X and Y vectors
    X, Y = None, None

    # Instantiate LLM2Vec Model
    llm2vec_model = load_llm2vec_model()

    # Preprocess both audio and lyrics
    audio, lyrics = single_preprocessing(audio, lyrics)

    # Call the train method for both models
    audio_features = spectttra_predict(audio)
    lyrics_features = l2vec_single_train(llm2vec_model, lyrics)

    # Reduce the lyrics using saved PCA model
    reduced_lyrics = load_pca_model(lyrics_features)

    # Scale the vectors using Z-Score
    audio_features, reduced_lyrics = instance_scaler(audio_features, reduced_lyrics)

    # Concatenate the vectors of audio_features + lyrics_features
    results = np.concatenate([audio_features, reduced_lyrics], axis=1)

    # ---- Load MLP Classifier ----
    config = load_config("config/model_config.yml")
    classifier = build_mlp(input_dim=results.shape[1], config=config)

    # Load trained weights (make sure this path matches where you saved your model)
    model_path = "models/mlp/mlp_multimodal.pth"
    classifier.load_model(model_path)
    classifier.model.eval()

    # Run prediction
    probability, prediction, label = classifier.predict_single(results)

    return {
        "probability": probability,
        "label": label,
        "prediction": "AI-Generated" if prediction == 0 else "Human-Composed",
    }


if __name__ == "__main__":
    # Example usage (replace with real inputs, place song inside data/raw.)
    audio = "sample"
    lyrics = "Some lyrics text here"
    print(predict_pipeline(audio, lyrics))
scripts/train.py
ADDED
@@ -0,0 +1,160 @@
from src.preprocessing.preprocessor import dataset_read, bulk_preprocessing
from src.spectttra.spectttra_trainer import spectttra_train
from src.llm2vectrain.model import load_llm2vec_model
from src.llm2vectrain.llm2vec_trainer import l2vec_train
from src.models.mlp import build_mlp, load_config

from src.utils.config_loader import DATASET_NPZ, PCA_MODEL
from src.utils.dataset import dataset_scaler, dataset_splitter
from sklearn.decomposition import PCA

from pathlib import Path
import numpy as np
import logging
import joblib

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def train_mlp_model(data : dict):
    """
    Train the MLP model with extracted features.

    Parameters
    ----------
    data : dict{np.array}
        A dictionary of np.arrays, containing the train/test/val split.
    """
    logger.info("Starting MLP training...")

    # Load MLP configuration
    config = load_config("config/model_config.yml")

    # Destructure the dictionary to get data split
    X_train, y_train = data["train"]
    X_val, y_val = data["val"]
    X_test, y_test = data["test"]

    # Build and train MLP
    mlp_classifier = build_mlp(input_dim=X_train.shape[1], config=config)

    # Show model summary
    mlp_classifier.get_model_summary()

    # Train the model
    history = mlp_classifier.train(X_train, y_train, X_val, y_val)

    # Load best model and evaluate on test set
    try:
        mlp_classifier.load_model("models/mlp/mlp_best.pth")
        logger.info("Loaded best model for final evaluation")
    except FileNotFoundError:
        logger.warning("Best model not found, using current model")

    # Final evaluation
    test_results = mlp_classifier.evaluate(X_test, y_test)

    # Save final model
    mlp_classifier.save_model("models/mlp/mlp_multimodal.pth")

    logger.info("MLP training completed successfully!")
    logger.info(f"Final test accuracy: {test_results['test_accuracy']:.2f}%")

    return mlp_classifier


def train_pipeline():
    """
    Training script which includes preprocessing, feature extraction, and training the MLP model.

    The train pipeline saves the train dataset in an .npz format.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """

    # Instantiate X and Y vectors
    X, Y = None, None

    dataset_path = Path(DATASET_NPZ)

    if dataset_path.exists():
        logger.info("Training dataset already exists. Loading file...")

        loaded_data = np.load(DATASET_NPZ)
        X = loaded_data["X"]
        Y = loaded_data["Y"]
    else:
        logger.info("Training dataset does not exist. Processing data...")
        # Get batches from dataset and return full Y labels
        batches, Y = dataset_read(batch_size=500)
        batch_count = 1

        # Instantiate LLM2Vec and PCA model
        llm2vec_model = load_llm2vec_model()

        # Preallocate spaces for both audio and lyric vectors to reduce memory overhead
        audio_vectors = np.zeros((len(Y), 384), dtype=np.float32)
        lyric_vectors = np.zeros((len(Y), 4096), dtype=np.float32)

        start_idx = 0
        for batch in batches:

            logger.info(f"Bulk Preprocessing - Batch {batch_count}.")
            audio, lyrics = bulk_preprocessing(batch, batch_count)
            batch_count += 1

            # Call the train methods for both SpecTTTra and LLM2Vec
            logger.info("Starting SpecTTTra feature extraction...")
            audio_features = spectttra_train(audio)

            logger.info("Starting LLM2Vec feature extraction...")
            lyrics_features = l2vec_train(llm2vec_model, lyrics)

            batch_size = audio_features.shape[0]

            # Store the results on preallocated spaces
            audio_vectors[start_idx:start_idx + batch_size, :] = audio_features
            lyric_vectors[start_idx:start_idx + batch_size, :] = lyrics_features

            # Delete stored instance for next batch to remove overhead
            del audio, lyrics, audio_features, lyrics_features

        # Run standard scaling on audio and lyrics separately
        logger.info("Running standard scaling for audio and lyrics...")
        audio_vectors, lyric_vectors = dataset_scaler(audio_vectors, lyric_vectors)

        # Start training the PCA to the collected lyrics features
        logger.info("PCA Training on lyric vectors...")
        pca = PCA(n_components=256, svd_solver="randomized", random_state=42)
        lyric_vectors = pca.fit_transform(lyric_vectors)

        # Save the trained PCA model
        joblib.dump(pca, "models/fusion/pca.pkl")

        # Concatenate audio features and reduced lyrics features
        X = np.concatenate([audio_vectors, lyric_vectors], axis=1)
        logger.info(f"Audio and Lyrics Concatenated. Final features shape: {X.shape}")

        # Convert label list into np.array
        Y = np.array(Y)

        # Save both X and Y to an .npz file for easier loading
        logger.info("Saving dataset for future testing...")
        np.savez(DATASET_NPZ, X=X, Y=Y)

    # Do data splitting
    data = dataset_splitter(X, Y)

    logger.info("Starting MLP training...")
    train_mlp_model(data)


if __name__ == "__main__":
    train_pipeline()
src/__init__.py
ADDED
File without changes
src/features/__init__.py
ADDED
File without changes
src/features/llm2vec.py
ADDED
File without changes
src/features/spectttra.py
ADDED
File without changes
src/llm2vectrain/__init__.py
ADDED
File without changes
src/llm2vectrain/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (145 Bytes).
src/llm2vectrain/__pycache__/access_token.cpython-312.pyc
ADDED
Binary file (208 Bytes).
src/llm2vectrain/__pycache__/llm2vec_trainer.cpython-312.pyc
ADDED
Binary file (7.51 kB).
src/llm2vectrain/__pycache__/model.cpython-312.pyc
ADDED
Binary file (2.03 kB).
src/llm2vectrain/config.py
ADDED
@@ -0,0 +1,5 @@
import os
from dotenv import load_dotenv

load_dotenv()
access_token = os.getenv("HF_TOKEN")
src/llm2vectrain/llm2vec_trainer.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sklearn.decomposition import IncrementalPCA
from sklearn.preprocessing import StandardScaler
from pathlib import Path

import numpy as np
import pickle
import torch
import os
import joblib

# Initialize PCA and StandardScaler globally for training
_pca_trainer = None


class SimplePCATrainer:
    """
    A simple PCA trainer that uses IncrementalPCA to fit data in batches.
    It saves checkpoints every 5 batches and can save the final model.

    Args:
        None

    Returns:
        None

    Attributes:
        pca: The IncrementalPCA model.
        scaler: StandardScaler for normalizing data.
        fitted: Boolean indicating if the model has been initialized.
        batch_count_pca: Counter for the number of batches processed.

    Methods:
        process_batch(vectors): Processes a batch of vectors, fits the PCA model incrementally.
        save_final(model_path): Saves the final PCA model to the specified path.
    """

    # Initialize the trainer
    def __init__(self):
        self.pca = None
        self.scaler = StandardScaler()
        self.fitted = False
        self.batch_count_pca = 0

    def _determine_optimal_components(self, vectors):
        """
        Determine the optimal number of PCA components to retain 95% variance.

        Args:
            vectors: The input data to analyze.
        Returns:
            n_components: The optimal number of components.
        """
        temp_pca = IncrementalPCA()
        temp_pca.fit(vectors)
        cumsum_var = np.cumsum(temp_pca.explained_variance_ratio_)
        n_comp_95 = np.argmax(cumsum_var >= 0.95) + 1
        return min(n_comp_95, vectors.shape[1] // 2)

    def process_batch(self, vectors):
        """
        Process a batch of vectors, fitting the PCA model incrementally.

        Args:
            vectors: The input data batch to process.
        Returns:
            reduced_vectors: The PCA-transformed data.

        Note: This method saves a checkpoint every 5 batches.
        """
        if not self.fitted:
            # First batch - initialize everything
            n_components = self._determine_optimal_components(vectors)
            self.pca = IncrementalPCA(n_components=n_components, batch_size=1000)
            self.scaler.fit(vectors)
            self.fitted = True
            print(f"Initialized PCA with {n_components} components")

        # Process batch
        vectors_scaled = self.scaler.transform(vectors)
        self.pca.partial_fit(vectors_scaled)
        reduced_vectors = self.pca.transform(vectors_scaled)

        self.batch_count_pca += 1

        # Save checkpoint every 5 batches
        if self.batch_count_pca % 5 == 0:
            os.makedirs("pca_checkpoints", exist_ok=True)
            with open(f"pca_checkpoints/checkpoint_batch_{self.batch_count_pca}.pkl", 'wb') as f:
                pickle.dump({'pca': self.pca, 'scaler': self.scaler}, f)
            print(f"Saved checkpoint at batch {self.batch_count_pca}")

        print(f"Processed batch {self.batch_count_pca}, shape: {vectors.shape} -> {reduced_vectors.shape}")
        return reduced_vectors

    def save_final(self, model_path):
        """
        Save the final PCA model to the specified path.

        Args:
            model_path: The file path to save the PCA model.

        Returns:
            None

        Note: Change the model path as needed in the data_config.yml file.
        """
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        with open(model_path, 'wb') as f:
            pickle.dump({'pca': self.pca, 'scaler': self.scaler}, f)
        print(f"Final model saved to {model_path}. Total variance explained: {np.sum(self.pca.explained_variance_ratio_):.4f}")


# For Single Input
def load_pca_model(vectors, model_path="models/fusion/pca.pkl"):
    """
    Load a pre-trained PCA model and transform the input vectors.

    Args:
        vectors: The input data to transform.
        model_path: The file path of the pre-trained PCA model.

    Returns:
        output: The PCA-transformed data.

    Note: Change the model path as needed in the data_config.yml file (or pass the
    path explicitly). Can be used for the main program.
    """
    model_path = Path(model_path)
    saved = joblib.load(model_path)
    # save_final() stores a {'pca': ..., 'scaler': ...} dict, so apply the scaler
    # before the PCA transform; a bare PCA object is also supported.
    if isinstance(saved, dict):
        return saved['pca'].transform(saved['scaler'].transform(vectors))
    return saved.transform(vectors)


def l2vec_single_train(l2v, lyrics):
    """
    Encode a single lyric string using the provided LLM2Vec model.

    Args:
        l2v: The LLM2Vec model for encoding lyrics.
        lyrics: A single lyric string to encode.

    Returns:
        vectors: The vector representation of the lyrics.
    """
    vectors = l2v.encode([lyrics]).detach().cpu().numpy()
    return vectors


# For Batch Processing
def l2vec_train(l2v, lyrics_list):
    """
    Encode a list of lyric strings using the provided LLM2Vec model.

    Args:
        l2v: The LLM2Vec model for encoding lyrics.
        lyrics_list: A list of lyric strings to encode.
    Returns:
        vectors: The encoded vector representations of the lyrics.

    Note: This function only encodes the lyrics and does not apply PCA reduction. The PCA reduction can be applied separately in the train.py module.
    """
    with torch.no_grad():
        vectors = l2v.encode(lyrics_list)  # lyrics_list: list of strings
    return vectors
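A minimal sketch of how the batch encoder and the PCA trainer above could be chained during feature extraction. The batch iterable, import paths, and output path are illustrative assumptions, not the project's actual training script (the real flow lives in scripts/train.py):

# Illustrative only: chaining l2vec_train with SimplePCATrainer.
from src.llm2vectrain.model import load_llm2vec_model
from src.llm2vectrain.llm2vec_trainer import SimplePCATrainer, l2vec_train

l2v = load_llm2vec_model()
pca_trainer = SimplePCATrainer()

for lyrics_batch in batches_of_lyrics:            # hypothetical iterable of list[str]
    vectors = l2vec_train(l2v, lyrics_batch)      # encode lyrics to embeddings (torch tensor)
    vectors = vectors.detach().cpu().numpy()      # move to NumPy for scikit-learn
    reduced = pca_trainer.process_batch(vectors)  # incremental PCA fit + transform

pca_trainer.save_final("models/fusion/pca.pkl")   # path configurable via data_config.yml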
src/llm2vectrain/model.py
ADDED
@@ -0,0 +1,51 @@
from llm2vec import LLM2Vec
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
from src.llm2vectrain.config import access_token
import torch


def load_llm2vec_model():

    model_id = "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp"

    tokenizer = AutoTokenizer.from_pretrained(
        model_id, padding=True, truncation=True, max_length=512
    )
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    if torch.cuda.is_available():
        # GPU path: use bf16 for speed
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=config,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            token=access_token,
        )
    else:
        # CPU path: load in float32 first, then quantize
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=config,
            torch_dtype=torch.float32,  # quantization requires fp32
            device_map="cpu",
            token=access_token,
        )

        try:
            # Import here so a missing torchao degrades gracefully instead of
            # failing at module import time.
            from torchao.quantization import quantize_, Int8WeightOnlyConfig

            print("[INFO] Applying torchao quantization with Int8WeightOnlyConfig...")
            quant_config = Int8WeightOnlyConfig(group_size=None)
            quantize_(model, quant_config)
        except ImportError:
            print("[WARNING] torchao not installed. Run: pip install torchao")
            print("[WARNING] Falling back to non-quantized CPU model.")

    l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)
    return l2v
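For single-song inference, the loader above can be combined with the helpers in llm2vec_trainer.py roughly as follows. This is a sketch that assumes models/fusion/pca.pkl holds the object written by SimplePCATrainer.save_final(); it is not the project's exact serving path:

from src.llm2vectrain.model import load_llm2vec_model
from src.llm2vectrain.llm2vec_trainer import l2vec_single_train, load_pca_model

l2v = load_llm2vec_model()                                 # quantized on CPU, bf16 on GPU
lyric_vec = l2vec_single_train(l2v, "lyrics of one song")  # shape (1, embedding_dim)
lyric_vec_reduced = load_pca_model(lyric_vec)              # PCA-reduced lyric features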
src/models/__init__.py
ADDED
File without changes
src/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (139 Bytes).
src/models/__pycache__/mlp.cpython-312.pyc
ADDED
Binary file (32.2 kB).
src/models/fusion.py
ADDED
File without changes
src/models/mlp.py
ADDED
@@ -0,0 +1,753 @@
"""
MLP Classifier for AI vs Human Music Detection
==============================================

This is our main classifier that determines if a piece of music was created by AI or by humans.

What it does:
- Takes combined features from LLM2Vec (text) + Spectra (audio)
- Feeds them through a neural network
- Outputs: "This sounds like AI" or "This sounds human"

Quick Start:
---------------------------
# 1. Load settings from config file
config = load_config("config/model_config.yml")

# 2. Combine LLM2Vec and Spectra features
combined_features = np.concatenate([llm2vec_features, spectra_features], axis=1)

# 3. Create classifier
classifier = MLPClassifier(input_dim=combined_features.shape[1], config=config)

# 4. Train it
history = classifier.train(X_train, y_train, X_val, y_val)

# 5. Test it
results = classifier.evaluate(X_test, y_test)

# 6. Use it for new predictions
probabilities, predictions = classifier.predict(new_music_features)

How the Neural Network Works:
-----------------------------
Input        →   Hidden Layers   →   Output
  ↓                   ↓                 ↓
Features          Processing         AI/Human
(LLM2Vec +        (Multiple          (0 or 1)
 Spectra)          layers)

The network learns patterns that help distinguish AI-generated music from human music.
"""

from typing import Dict, Tuple
from pathlib import Path
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import classification_report, confusion_matrix

import logging
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import yaml

logger = logging.getLogger(__name__)


class MLPModel(nn.Module):
    """
    The actual neural network that does the AI vs Human classification.

    What happens inside:
    1. Takes the combined LLM2Vec + Spectra features
    2. Passes them through multiple hidden layers (each layer learns different patterns)
    3. Each layer applies: processing → normalization → activation → dropout
    4. Final layer outputs a probability (0-1) where closer to 1 = "more human-like"

    Args:
        input_dim (int): How many features we have total (LLM2Vec size + Spectra size)
        config (Dict): Settings from the YAML file that specify:
            - "hidden_layers": How many neurons in each layer [128, 64, 32]
            - "dropout": How much to randomly "forget" to prevent overfitting [0.3, 0.5, 0.2]
    """

    def __init__(self, input_dim: int, config: Dict):
        """
        Build the neural network architecture based on our config file.
        """
        super(MLPModel, self).__init__()

        self.hidden_layers = config["hidden_layers"]
        self.dropout_rates = config["dropout"]

        # Build layers with batch normalization
        layers = []
        prev_dim = input_dim

        # First, normalize the input features (makes training more stable)
        layers.append(nn.BatchNorm1d(input_dim))

        # Build hidden layers
        for i, units in enumerate(self.hidden_layers):
            # Main processing layer
            layers.append(nn.Linear(prev_dim, units))

            # Batch normalization (helps with training stability)
            layers.append(nn.BatchNorm1d(units))

            # Activation function (allows network to learn complex patterns)
            layers.append(nn.LeakyReLU(negative_slope=0.01))

            # Randomly "forget" some connections to prevent overfitting
            dropout_rate = self.dropout_rates[i] if i < len(self.dropout_rates) else 0.5
            if dropout_rate > 0:
                layers.append(nn.Dropout(dropout_rate))

            prev_dim = units

        # Final output layer: gives us the AI vs Human probability
        layers.append(nn.Linear(prev_dim, 1))
        # Squeezes output between 0 and 1
        layers.append(nn.Sigmoid())

        self.network = nn.Sequential(*layers)
        self._initialize_weights()

        logger.info(
            f"Built MLP with {len(self.hidden_layers)} hidden layers: {self.hidden_layers}"
        )

    def _initialize_weights(self):
        """
        Set up the starting weights for training.

        Uses Xavier initialization - a way to set initial weights
        so the network trains better from the start.
        """
        for layer in self.network:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=0.5)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process input features through the network to get predictions.

        Args:
            x: Our combined music features (LLM2Vec + Spectra)

        Returns:
            Probability that the music is human-composed (0 to 1)
        """
        return self.network(x)

    @staticmethod
    def mixup(X, y, alpha=0.2):
        """Apply MixUp augmentation to a batch."""
        if alpha <= 0:
            return X, y, y, 1.0  # no mixing

        lam = np.random.beta(alpha, alpha)
        batch_size = X.size(0)
        index = torch.randperm(batch_size).to(X.device)

        mixed_X = lam * X + (1 - lam) * X[index]
        y_a, y_b = y, y[index]
        return mixed_X, y_a, y_b, lam

    @staticmethod
    def mixup_loss(criterion, pred, y_a, y_b, lam):
        """Compute MixUp loss."""
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


class MLPClassifier:
    """
    The complete music classifier system that wraps everything together.

    This handles all the training, testing, and prediction logic.

    What it manages:
    - The neural network model
    - Training process (with smart features like early stopping)
    - Making predictions on new music
    - Saving/loading trained models
    """

    def __init__(self, input_dim: int, config: Dict):
        """
        Set up the complete classification system.

        Args:
            input_dim (int): Total number of features (LLM2Vec + Spectra combined)
            config (Dict): All our settings from the YAML config file

        This creates:
        - The neural network
        - The training optimizer (Adam - good for most cases)
        - Learning rate scheduler (automatically adjusts learning speed)
        - Loss function (measures how wrong our predictions are)
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Build the neural network
        self.model = MLPModel(input_dim, config).to(self.device)

        # Optimizer: the algorithm that improves the network during training
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get("learning_rate", 0.001),
            weight_decay=config.get("weight_decay", 0.01),
        )

        # Scheduler: automatically reduces learning rate if we get stuck
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-7
        )

        # Loss function: measures how wrong our predictions are
        self.criterion = nn.BCELoss()

        self.is_trained = False

        logger.info(f"Using device: {self.device}")
        logger.info(
            f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}"
        )

    def _create_data_loader(
        self, X: np.ndarray, y: np.ndarray, shuffle: bool = True
    ) -> DataLoader:
        """
        Convert the numpy arrays into batches that PyTorch can process.
        """
        X_tensor = torch.FloatTensor(X)
        y_tensor = torch.FloatTensor(y).unsqueeze(1)

        dataset = TensorDataset(X_tensor, y_tensor)
        return DataLoader(
            dataset, batch_size=self.config["batch_size"], shuffle=shuffle
        )

    def train(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: np.ndarray,
        y_val: np.ndarray,
    ) -> Dict:
        """
        Train the model to recognize AI vs Human music patterns.

        The model learns by:
        1. Looking at training examples (music + labels)
        2. Making predictions
        3. Seeing how wrong it was
        4. Adjusting its parameters to do better
        5. Repeating thousands of times

        Args:
            X_train: Training music features (LLM2Vec + Spectra combined)
            y_train: Training labels (0 = AI-generated, 1 = human-composed)
            X_val: Validation features (used to check if we're overfitting)
            y_val: Validation labels

        Returns:
            Dict: Training history showing how loss and accuracy changed over time

        Smart features included:
        - Early stopping: stops training if validation performance gets worse
        - Learning rate scheduling: slows down learning if we get stuck
        - Gradient clipping: prevents exploding gradients
        - Progress bars: so we can see what's happening (via tqdm)
        """
        logger.info("Starting MLP training...")

        # Prepare the data for training
        train_loader = self._create_data_loader(X_train, y_train, shuffle=True)
        val_loader = self._create_data_loader(X_val, y_val, shuffle=False)

        # Track training progress
        history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

        # Early stopping variables
        best_val_loss = float("inf")
        patience_counter = 0
        patience = self.config["patience"]

        # Main training loop
        for epoch in range(self.config["epochs"]):
            # Training phase - model learns from training data
            self.model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            train_pbar = tqdm(
                train_loader, desc=f"Epoch {epoch+1}/{self.config['epochs']} [Train]"
            )
            for batch_X, batch_y in train_pbar:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                # Forward pass: make predictions
                self.optimizer.zero_grad()

                # Adding training augmentation if mixup value > 0
                if self.config.get("mixup_alpha", 0) > 0:
                    mixed_X, y_a, y_b, lam = MLPModel.mixup(
                        batch_X, batch_y, alpha=self.config["mixup_alpha"]
                    )
                    outputs = self.model(mixed_X)
                    loss = MLPModel.mixup_loss(self.criterion, outputs, y_a, y_b, lam)
                else:
                    outputs = self.model(batch_X)
                    loss = self.criterion(outputs, batch_y)

                # Backward pass: learn from mistakes
                loss.backward()

                # Prevent gradients from getting too large (helps stability)
                if self.config.get("gradient_clipping"):
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config["gradient_clipping"]
                    )

                self.optimizer.step()

                # Track statistics
                train_loss += loss.item()
                # Convert probabilities to 0/1 predictions
                predicted = (outputs > 0.5).float()
                train_total += batch_y.size(0)
                train_correct += (predicted == batch_y).sum().item()

                # Update progress bar
                train_pbar.set_postfix(
                    {
                        "Loss": f"{loss.item():.4f}",
                        "Acc": f"{100.*train_correct/train_total:.2f}%",
                    }
                )

            # Calculate epoch averages
            avg_train_loss = train_loss / len(train_loader)
            train_acc = 100.0 * train_correct / train_total

            history["train_loss"].append(avg_train_loss)
            history["train_acc"].append(train_acc)

            # Validation phase - check how well we do on unseen data
            val_loss, val_acc = self._validate(val_loader)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)

            # Adjust learning rate if needed
            self.scheduler.step(val_loss)

            logger.info(
                f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%"
            )

            # Early stopping logic - save best model and stop if no improvement
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self.is_trained = True
                # Save the best version
                self.save_model("models/mlp/mlp_best.pth")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

        self.is_trained = True
        logger.info("MLP training completed!")
        return history

    def _validate(self, val_loader: DataLoader) -> Tuple[float, float]:
        """
        Test how well the model performs on validation/test data.

        This runs the model in "evaluation mode" - no learning happens,
        we just check how accurate our predictions are.

        Returns:
            Average loss and accuracy percentage
        """
        # Switch to evaluation mode
        self.model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        # Don't track gradients (saves memory and time)
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)

                val_loss += loss.item()
                # Convert to binary predictions
                predicted = (outputs > 0.5).float()
                val_total += batch_y.size(0)
                val_correct += (predicted == batch_y).sum().item()

        avg_val_loss = val_loss / len(val_loader)
        val_acc = 100.0 * val_correct / val_total

        return avg_val_loss, val_acc

    def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Use the trained model to classify new music as AI-generated or human-composed.

        Args:
            X: Music features (LLM2Vec + Spectra combined) for songs we want to classify

        Returns:
            probabilities: How confident the model is (0.0 to 1.0, higher = more human-like)
            predictions: Binary classifications (0 = AI-generated, 1 = human-composed)

        Example:
            probs, preds = classifier.predict(new_song_features)
            if preds[0] == 1:
                print(f"This sounds human-composed (confidence: {probs[0]:.2f})")
            else:
                print(f"This sounds AI-generated (confidence: {1-probs[0]:.2f})")
        """
        self.model.eval()
        # Create dummy labels since we don't know the true answers
        data_loader = self._create_data_loader(X, np.zeros(len(X)), shuffle=False)

        probabilities = []

        with torch.no_grad():
            for batch_X, _ in data_loader:
                batch_X = batch_X.to(self.device)
                outputs = self.model(batch_X)
                probabilities.extend(outputs.cpu().numpy())

        probabilities = np.array(probabilities).flatten()
        # Threshold at 0.5
        predictions = (probabilities > 0.5).astype(int)

        return probabilities, predictions

    def predict_single(self, features: np.ndarray) -> Tuple[float, int, str]:
        """
        Predict whether a single song is AI-generated or human-composed.

        This method is optimized for predicting one song at a time.

        Args:
            features: Music features for ONE song (LLM2Vec + Spectra combined)
                Should be 1D array with shape (feature_dim,)

        Returns:
            probability: Confidence score (0.0 to 1.0, higher = more human-like)
            prediction: Binary classification (0 = AI-generated, 1 = human-composed)
            label: Human-readable label ("AI-Generated" or "Human-Composed")

        Example:
            # For a single song
            single_song_features = np.array([0.1, 0.5, 0.3, ...])
            prob, pred, label = classifier.predict_single(single_song_features)

            print(f"Prediction: {label}")
            print(f"Confidence: {prob:.3f}")

            if pred == 1:
                print(f"This sounds {prob:.1%} human-composed")
            else:
                print(f"This sounds {(1-prob):.1%} AI-generated")
        """
        if not self.is_trained:
            raise ValueError(
                "Model must be trained before making predictions. Call train() first."
            )

        # Ensure input is the right shape
        if features.ndim == 1:
            features = features.reshape(1, -1)  # Convert to batch of size 1
        elif features.shape[0] != 1:
            raise ValueError(
                f"Expected features for 1 song, got {features.shape[0]} songs. Use predict_batch() instead."
            )

        # Use the existing predict method
        probabilities, predictions = self.predict(features)

        # Extract single results
        probability = float(probabilities[0])
        prediction = int(predictions[0])
        label = "Human-Composed" if prediction == 1 else "AI-Generated"

        return probability, prediction, label

    def predict_batch(self, features: np.ndarray, return_details: bool = False) -> Dict:
        """
        Predict AI vs Human classification for multiple songs at once.

        This method is optimized for batch processing - much faster than calling
        predict_single() multiple times.

        Args:
            features: Music features for MULTIPLE songs (LLM2Vec + Spectra combined)
                Should be 2D array with shape (num_songs, feature_dim)
            return_details: If True, includes additional statistics and breakdowns

        Returns:
            Dictionary containing:
            - 'probabilities': Confidence scores for each song (0.0 to 1.0)
            - 'predictions': Binary classifications (0 = AI, 1 = Human)
            - 'labels': Human-readable labels for each song
            - 'summary': Quick stats about the batch results
            - 'details': (if return_details=True) Additional analysis

        Example:
            # For multiple songs
            batch_features = np.array([[0.1, 0.5, 0.3, ...],   # Song 1
                                       [0.2, 0.4, 0.7, ...],   # Song 2
                                       [0.3, 0.6, 0.1, ...]])  # Song 3

            results = classifier.predict_batch(batch_features, return_details=True)

            print(f"Processed {len(results['predictions'])} songs")
            print(f"Summary: {results['summary']}")

            for i, (prob, pred, label) in enumerate(zip(results['probabilities'],
                                                        results['predictions'],
                                                        results['labels'])):
                print(f"Song {i+1}: {label} (confidence: {prob:.3f})")
        """
        if not self.is_trained:
            raise ValueError(
                "Model must be trained before making predictions. Call train() first."
            )

        # Ensure input is 2D
        if features.ndim == 1:
            raise ValueError(
                "For batch prediction, features should be 2D (num_songs, feature_dim). "
                "For single song, use predict_single() instead."
            )

        num_songs = features.shape[0]
        logger.info(f"Processing batch of {num_songs} songs...")

        # Get predictions using existing method
        probabilities, predictions = self.predict(features)

        # Convert to human-readable labels
        labels = [
            "Human-Composed" if pred == 1 else "AI-Generated" for pred in predictions
        ]

        # Calculate summary statistics
        num_human = np.sum(predictions == 1)
        num_ai = np.sum(predictions == 0)
        avg_confidence_human = (
            np.mean(probabilities[predictions == 1]) if num_human > 0 else 0.0
        )
        avg_confidence_ai = (
            np.mean(1 - probabilities[predictions == 0]) if num_ai > 0 else 0.0
        )

        summary = {
            "total_songs": num_songs,
            "human_composed": num_human,
            "ai_generated": num_ai,
            "human_percentage": (num_human / num_songs) * 100,
            "ai_percentage": (num_ai / num_songs) * 100,
            "avg_confidence_human": avg_confidence_human,
            "avg_confidence_ai": avg_confidence_ai,
        }

        results = {
            "probabilities": probabilities,
            "predictions": predictions,
            "labels": labels,
            "summary": summary,
        }

        # Add detailed analysis if requested
        if return_details:
            # Confidence distribution analysis
            high_confidence = np.sum((probabilities > 0.8) | (probabilities < 0.2))
            medium_confidence = np.sum(
                (probabilities >= 0.6) & (probabilities <= 0.8)
                | (probabilities >= 0.2) & (probabilities <= 0.4)
            )
            low_confidence = np.sum((probabilities > 0.4) & (probabilities < 0.6))

            # Most confident predictions
            sorted_indices = np.argsort(np.abs(probabilities - 0.5))[
                ::-1
            ]  # Most confident first
            most_confident_indices = sorted_indices[: min(5, len(sorted_indices))]
            least_confident_indices = sorted_indices[-min(5, len(sorted_indices)) :]

            details = {
                "confidence_distribution": {
                    "high_confidence": high_confidence,
                    "medium_confidence": medium_confidence,
                    "low_confidence": low_confidence,
                },
                "most_confident_predictions": {
                    "indices": most_confident_indices.tolist(),
                    "probabilities": probabilities[most_confident_indices].tolist(),
                    "predictions": predictions[most_confident_indices].tolist(),
                },
                "least_confident_predictions": {
                    "indices": least_confident_indices.tolist(),
                    "probabilities": probabilities[least_confident_indices].tolist(),
                    "predictions": predictions[least_confident_indices].tolist(),
                },
                "probability_stats": {
                    "mean": float(np.mean(probabilities)),
                    "std": float(np.std(probabilities)),
                    "min": float(np.min(probabilities)),
                    "max": float(np.max(probabilities)),
                    "median": float(np.median(probabilities)),
                },
            }
            results["details"] = details

        logger.info(
            f"Batch prediction completed: {num_human} human, {num_ai} AI-generated"
        )
        return results

    def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, float]:
        """
        Get detailed performance metrics on test data.

        This gives us the final report card for our model:
        - How accurate is it overall?
        - How well does it detect AI-generated music?
        - How well does it detect human-composed music?
        - What kinds of mistakes does it make?

        Args:
            X_test: Test music features
            y_test: True labels (0 = AI, 1 = Human)

        Returns:
            Dictionary with test loss and accuracy

        Also logs detailed reports including:
        - Precision, recall, F1-score for each class
        - Confusion matrix showing prediction vs reality
        """
        probabilities, predictions = self.predict(X_test)

        test_loader = self._create_data_loader(X_test, y_test, shuffle=False)
        test_loss, test_acc = self._validate(test_loader)

        results = {"test_loss": test_loss, "test_accuracy": test_acc}
        logger.info(f"Test Results: {results}")

        # Detailed performance breakdown
        report = classification_report(
            y_test, predictions, target_names=["AI-Generated", "Human-Composed"]
        )
        logger.info(f"Classification Report:\n{report}")

        # Confusion matrix: shows what the model confused
        cm = confusion_matrix(y_test, predictions)
        logger.info(f"Confusion Matrix:\n{cm}")

        return results

    def save_model(self, filepath: str) -> None:
        """
        Save our trained model so we can use it later.

        Args:
            filepath: Where to save the model

        Saves everything needed to reload the model:
        - The learned weights
        - Training settings
        - Optimizer state
        """
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "config": self.config,
                "is_trained": self.is_trained,
            },
            filepath,
        )
        logger.info(f"Model saved to {filepath}")

    def load_model(self, filepath: str) -> None:
        """
        Load a previously trained model.

        Args:
            filepath: Path to our saved model file

        After this, you can immediately use predict() and evaluate()
        without needing to train again.
        """
        checkpoint = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.config = checkpoint["config"]
        self.is_trained = checkpoint.get("is_trained", True)
        logger.info(f"Model loaded from {filepath}")

        # Temporary override while waiting for a bigger dataset and a model trained on it
        self.is_trained = True

    def get_model_summary(self) -> None:
        """
        Print out details about our model architecture.

        Useful for debugging or understanding what we've built.
        Shows the network structure and how many parameters it has.
        """
        logger.info("Model Architecture:")
        logger.info(self.model)
        total_params = sum(p.numel() for p in self.model.parameters())
        logger.info(f"Total parameters: {total_params:,}")


def build_mlp(input_dim: int, config: Dict) -> MLPClassifier:
    """
    Quick way to create an MLP classifier.

    Args:
        input_dim: Size of our combined features (LLM2Vec + Spectra)
        config: Our model settings from the YAML file

    Returns:
        Ready-to-use MLPClassifier instance
    """
    return MLPClassifier(input_dim, config)


def load_config(config_path: str = "config/model_config.yml") -> Dict:
    """
    Load our model settings from the YAML configuration file.

    Args:
        config_path: Path to our config file

    Returns:
        Dictionary with all our MLP settings (hidden layers, dropout, etc.)
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config["mlp"]
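The classifier reads its hyperparameters from the mlp section of config/model_config.yml via load_config(). A sketch of the keys the code above actually consumes; the values shown are placeholders, not the repository's real settings:

# Hypothetical values; only the key names come from the code above.
mlp_config = {
    "hidden_layers": [128, 64, 32],  # layer sizes for MLPModel
    "dropout": [0.3, 0.5, 0.2],      # per-layer dropout rates
    "batch_size": 32,                # used by _create_data_loader
    "epochs": 100,                   # maximum training epochs
    "patience": 10,                  # early-stopping patience
    "learning_rate": 0.001,          # optional, Adam learning rate
    "weight_decay": 0.01,            # optional, Adam weight decay
    "mixup_alpha": 0.2,              # optional, 0 disables MixUp augmentation
    "gradient_clipping": 1.0,        # optional, max gradient norm
}

classifier = build_mlp(input_dim=1024, config=mlp_config)  # input_dim = LLM2Vec dim + Spectra dim (placeholder)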
src/musiclime/__init__.py
ADDED
File without changes
src/musiclime/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (142 Bytes).
src/musiclime/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (142 Bytes).
src/musiclime/__pycache__/explainer.cpython-312.pyc
ADDED
Binary file (12.5 kB).
src/musiclime/__pycache__/explainer.cpython-313.pyc
ADDED
Binary file (12.6 kB).
src/musiclime/__pycache__/factorization.cpython-312.pyc
ADDED
Binary file (5.5 kB).
src/musiclime/__pycache__/musiclime.cpython-312.pyc
ADDED
Binary file (20.8 kB).
src/musiclime/__pycache__/musiclime_wrapper.cpython-312.pyc
ADDED
Binary file (10.6 kB).
src/musiclime/__pycache__/optimized_wrapper.cpython-312.pyc
ADDED
Binary file (12.2 kB).
src/musiclime/__pycache__/print_utils.cpython-312.pyc
ADDED
Binary file (288 Bytes).
src/musiclime/__pycache__/text_utils.cpython-312.pyc
ADDED
Binary file (2.46 kB).
src/musiclime/__pycache__/true_musiclime.cpython-312.pyc
ADDED
Binary file (11.7 kB).
src/musiclime/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (15.7 kB).
src/musiclime/__pycache__/wrapper.cpython-312.pyc
ADDED
Binary file (6.1 kB).