File size: 6,874 Bytes
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Fast API imports
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware

# Processing imports
import librosa
import io

# Utils/schemas imports
from app.schemas import (
    ErrorResponse,
    ModelInfoResponse,
    PredictionResponse,
    PredictionXAIResponse,
    WelcomeResponse,
)
from app.utils import load_config

# Model/XAI-related imports
from scripts.explain import musiclime
from scripts.predict import predict_pipeline


# Read the application configuration once, when the module is imported
config = load_config()

# Upload limits come from the "file_upload" section of the config;
# the configured megabyte limit is converted to bytes
_upload_settings = config["file_upload"]
MAX_FILE_SIZE = _upload_settings["max_file_size_mb"] * 1024 * 1024
MAX_LYRICS_LENGTH = _upload_settings["max_lyrics_length"]
ALLOWED_AUDIO_TYPES = _upload_settings["allowed_audio_types"]

# Build the FastAPI application, titled and versioned from the "server" section
_server_settings = config["server"]
app = FastAPI(title=_server_settings["title"], version=_server_settings["version"])

# Register CORS handling, forwarding each option from the "api.cors" section
cors_config = config["api"]["cors"]
app.add_middleware(
    CORSMiddleware,
    **{
        option: cors_config[option]
        for option in (
            "allow_origins",
            "allow_credentials",
            "allow_methods",
            "allow_headers",
        )
    },
)


async def validate_audio_file(audio_file: UploadFile = File(...)):
    """Validate the declared type and the size of an uploaded audio file.

    The content type is checked first so an unsupported upload is rejected
    without reading its payload. The payload is then read with an upper bound
    of one byte past the limit, which is enough to detect an oversized file
    without pulling all of it into memory.

    Args:
        audio_file: The uploaded file taken from the multipart form.

    Returns:
        A ``(audio_file, audio_content)`` tuple, where ``audio_content`` is
        the raw bytes of the upload and ``audio_file`` has been rewound to
        the start.

    Raises:
        HTTPException: 400 if the content type is not in
            ``ALLOWED_AUDIO_TYPES`` or the payload exceeds ``MAX_FILE_SIZE``.
    """
    # Check file type (cheap, needs no I/O)
    if audio_file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Supported formats: {', '.join(ALLOWED_AUDIO_TYPES)}",
        )

    # Check file size; int() guards against a fractional limit from the config,
    # since read() only accepts an integer byte count
    audio_content = await audio_file.read(int(MAX_FILE_SIZE) + 1)
    if len(audio_content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB.",
        )

    # Reset file pointer for later use
    await audio_file.seek(0)
    return audio_file, audio_content


def validate_lyrics(lyrics: str = Form(...)):
    """Validate the submitted lyrics and return them stripped.

    Leading and trailing whitespace is removed before any check, so the
    length limit applies to the text that is actually returned rather than
    to padding around it.

    Args:
        lyrics: The lyrics taken from the form field.

    Returns:
        The lyrics with leading and trailing whitespace removed.

    Raises:
        HTTPException: 400 if the stripped lyrics are empty or longer than
            ``MAX_LYRICS_LENGTH`` characters.
    """
    # Basic sanitization, remove surrounding whitespace
    lyrics = lyrics.strip()
    if not lyrics:
        raise HTTPException(
            status_code=400,
            detail="Lyrics cannot be empty.",
        )

    if len(lyrics) > MAX_LYRICS_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Lyrics too long. Maximum length is {MAX_LYRICS_LENGTH} characters.",
        )

    return lyrics


@app.get("/", response_model=WelcomeResponse, tags=["Root"])
def root():
    """
    Confirm the API is running and list the routes it exposes.
    """
    # Route path -> short human-readable description
    available_endpoints = {
        "/": "This welcome message",
        "/docs": "FastAPI auto-generated API docs",
        "/api/v1/model/info": "Model information and capabilities",
        "/api/v1/predict": "POST endpoint for bach-or-bot prediction",
        "/api/v1/explain": "POST endpoint for prediction with explainability",
    }

    welcome = WelcomeResponse(
        status="success",
        message="Welcome to Bach or Bot API!",
        endpoints=available_endpoints,
    )
    return welcome


@app.post(
    "/api/v1/predict",
    response_model=PredictionResponse,
    responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
)
async def predict_music(
    lyrics: str = Depends(validate_lyrics), audio_file_data=Depends(validate_audio_file)
):
    """
    Endpoint to predict whether a music sample is human-composed or AI-generated.

    Args:
        lyrics: Lyrics already checked by ``validate_lyrics``.
        audio_file_data: The ``(audio_file, audio_content)`` pair produced by
            ``validate_audio_file``.

    Returns:
        A ``PredictionResponse`` carrying the lyrics, the upload's metadata
        and the output of ``predict_pipeline``.

    Raises:
        HTTPException: 400 if the audio bytes cannot be decoded, 500 for any
            other failure while producing the prediction.
    """
    try:
        # Get the audio file and content from sanitized and cleaned audio file
        audio_file, audio_content = audio_file_data

        # Load audio from uploaded file with error handling for corrupted files
        try:
            # The sample rate returned by librosa is not used here
            audio_data, _ = librosa.load(io.BytesIO(audio_content))
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid audio file: {str(e)}"
            ) from e

        # Call MLP predict runner script to get results
        results = predict_pipeline(audio_data, lyrics)

        return PredictionResponse(
            status="success",
            lyrics=lyrics,
            audio_file_name=audio_file.filename,
            audio_content_type=audio_file.content_type,
            audio_file_size=len(audio_content),
            results=results,
        )
    except HTTPException:
        # HTTPException is itself an Exception; re-raise it untouched so the
        # 400 raised above is not rewritten as a 500 by the handler below
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post(
    "/api/v1/explain",
    response_model=PredictionXAIResponse,
    responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
)
async def predict_music_with_xai(
    lyrics: str = Depends(validate_lyrics), audio_file_data=Depends(validate_audio_file)
):
    """
    Endpoint to predict whether a music sample is human-composed or AI-generated with explainability.

    Args:
        lyrics: Lyrics already checked by ``validate_lyrics``.
        audio_file_data: The ``(audio_file, audio_content)`` pair produced by
            ``validate_audio_file``.

    Returns:
        A ``PredictionXAIResponse`` carrying the lyrics, the upload's metadata
        and the output of ``musiclime``.

    Raises:
        HTTPException: 400 if the audio bytes cannot be decoded, 500 for any
            other failure while producing the explanation.
    """
    try:
        # Get the audio file and content from sanitized and cleaned audio file
        audio_file, audio_content = audio_file_data

        # Load audio from uploaded file with error handling for corrupted files
        try:
            # The sample rate returned by librosa is not used here
            audio_data, _ = librosa.load(io.BytesIO(audio_content))
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid audio file: {str(e)}"
            ) from e

        # Call musiclime runner script to get results
        results = musiclime(audio_data, lyrics)

        return PredictionXAIResponse(
            status="success",
            lyrics=lyrics,
            audio_file_name=audio_file.filename,
            audio_content_type=audio_file.content_type,
            audio_file_size=len(audio_content),
            results=results,
        )
    except HTTPException:
        # HTTPException is itself an Exception; re-raise it untouched so the
        # 400 raised above is not rewritten as a 500 by the handler below
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/api/v1/model/info", response_model=ModelInfoResponse, tags=["Model"])
async def get_model_info():
    """
    Describe the current model: name, version, accepted formats and upload limit.
    """
    try:
        # Drop the "audio/" text from each allowed MIME type from the config
        format_names = []
        for mime_type in ALLOWED_AUDIO_TYPES:
            format_names.append(mime_type.replace("audio/", ""))

        upload_limit_mb = config["file_upload"]["max_file_size_mb"]

        training_details = {
            "dataset": "Human-Composed and AI-generated music samples",
            "architecture": "To be specified",  # TODO: Update when model is implemented
            "accuracy": "To be determined",  # TODO: Update with actual metrics
        }

        return ModelInfoResponse(
            status="success",
            message="Model information retrieved successfully",
            model_name="Bach or Bot",
            model_version="1.0.0",  # TODO: Load from model metadata when available
            supported_formats=format_names,
            max_file_size_mb=upload_limit_mb,
            training_info=training_details,
            last_updated="2024-01-01T00:00:00Z",  # TODO: Update with actual timestamp
        )

    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))