// stem-separator/frontend/src/hooks/useSeparation.ts
// Uploaded by sourav-das via huggingface_hub ("Upload folder using huggingface_hub"), commit a0f35dd (verified).
import { useReducer, useCallback, useRef } from "react";
import type {
AppState,
AppAction,
OutputFormat,
SourceKind,
StemAsset,
StemResult,
} from "../types";
import {
fetchExampleOutput,
getAudioUrl,
getDownloadAllUrl,
getDownloadUrl,
importUrl,
uploadFile,
startSeparation,
subscribeProgress,
} from "../api";
/**
 * Pure transition function for the separation workflow.
 *
 * UPLOAD_START, UPLOAD_DONE, LOAD_EXAMPLE_DONE, ERROR and RESET are accepted in any
 * phase. Every other action only applies in one phase (UPLOAD_PROGRESS while
 * "uploading", SET_OUTPUT_FORMAT / SEPARATE_START while "uploaded",
 * SEPARATE_PROGRESS / SEPARATE_DONE while "separating"). In any other phase those
 * actions return `state` itself, unchanged. Unknown actions are ignored the same way.
 *
 * Phase changes copy the track fields one by one instead of spreading the previous
 * state. This keeps fields owned by the old phase (such as separation progress) from
 * leaking into the new one.
 */
function reducer(state: AppState, action: AppAction): AppState {
  switch (action.type) {
    case "UPLOAD_START": {
      const { message } = action;
      return { phase: "uploading", progress: 0, message };
    }

    case "UPLOAD_PROGRESS": {
      if (state.phase !== "uploading") {
        return state;
      }
      return { ...state, progress: action.progress };
    }

    case "UPLOAD_DONE": {
      const {
        jobId,
        filename,
        originalUrl,
        outputFormat,
        songName,
        sourceKind,
        sourceUrl,
        resolvedUrl,
      } = action;
      return {
        phase: "uploaded",
        jobId,
        filename,
        originalUrl,
        outputFormat,
        songName,
        sourceKind,
        sourceUrl,
        resolvedUrl,
      };
    }

    case "SET_OUTPUT_FORMAT": {
      if (state.phase !== "uploaded") {
        return state;
      }
      return { ...state, outputFormat: action.outputFormat };
    }

    case "SEPARATE_START": {
      if (state.phase !== "uploaded") {
        return state;
      }
      const {
        jobId,
        filename,
        originalUrl,
        outputFormat,
        songName,
        sourceKind,
        sourceUrl,
        resolvedUrl,
      } = state;
      return {
        phase: "separating",
        jobId,
        filename,
        originalUrl,
        outputFormat,
        songName,
        sourceKind,
        sourceUrl,
        resolvedUrl,
        state: "queued",
        progress: 0,
        message: "Starting separation...",
      };
    }

    case "SEPARATE_PROGRESS": {
      if (state.phase !== "separating") {
        return state;
      }
      const { progress, message } = action;
      return { ...state, state: action.state, progress, message };
    }

    case "SEPARATE_DONE": {
      if (state.phase !== "separating") {
        return state;
      }
      const {
        jobId,
        filename,
        originalUrl,
        outputFormat,
        songName,
        sourceKind,
        sourceUrl,
        resolvedUrl,
      } = state;
      return {
        phase: "done",
        jobId,
        filename,
        originalUrl,
        outputFormat,
        songName,
        sourceKind,
        sourceUrl,
        resolvedUrl,
        stems: action.stems,
        downloadAllUrl: action.downloadAllUrl,
      };
    }

    case "LOAD_EXAMPLE_DONE": {
      const { original, stems, downloadAllUrl } = action;
      return { phase: "example", original, stems, downloadAllUrl };
    }

    case "ERROR":
      return { phase: "error", message: action.message };

    case "RESET":
      return { phase: "idle" };

    default:
      return state;
  }
}
/**
 * React hook that owns the upload → separate → results workflow.
 *
 * Each entry point (`upload`, `importFromUrl`, `separate`, `loadExample`, `reset`) first
 * supersedes whatever ran before it. It closes the active progress subscription and marks
 * every still-pending async step of earlier calls as stale. Stale continuations return
 * without dispatching. As a result:
 * - an upload that finishes after the user pressed reset can no longer pull the UI back
 *   into the "uploaded" phase;
 * - a repeated `separate` call no longer leaks the previous subscription.
 *
 * @returns The current workflow `state` together with action callbacks whose identities
 *   are stable across renders.
 */
export function useSeparation() {
  const [state, dispatch] = useReducer(reducer, { phase: "idle" });
  // Teardown for the progress subscription opened by `separate`, if one is active.
  const cleanupRef = useRef<(() => void) | null>(null);
  // Incremented each time a new operation starts. Pending work compares against the
  // value it captured to detect that it has been superseded.
  const generationRef = useRef(0);
  // NOTE(review): the subscription is still not closed when the host component
  // unmounts; consider calling `beginOperation()` from a useEffect cleanup.

  /**
   * Closes the active subscription (if any) and starts a new operation.
   *
   * @returns A predicate that becomes `true` once a later operation (or reset) begins,
   *   meaning results of the current one must be discarded.
   */
  const beginOperation = useCallback((): (() => boolean) => {
    const cleanup = cleanupRef.current;
    cleanupRef.current = null;
    cleanup?.();
    const token = ++generationRef.current;
    return () => generationRef.current !== token;
  }, []);

  const upload = useCallback(
    async (file: File) => {
      const isStale = beginOperation();
      try {
        dispatch({ type: "UPLOAD_START", message: "Uploading audio file..." });
        const result = await uploadFile(file, (progress) => {
          if (!isStale()) {
            dispatch({ type: "UPLOAD_PROGRESS", progress });
          }
        });
        // The user may have reset or started something else while uploading.
        if (isStale()) return;
        dispatch({
          type: "UPLOAD_DONE",
          jobId: result.job_id,
          filename: result.filename,
          originalUrl: getAudioUrl(
            result.job_id,
            `input${getExtFromFilename(result.filename)}`
          ),
          outputFormat: getDefaultOutputFormat(result.filename),
          songName: stripExtension(result.filename),
          sourceKind: "file",
        });
      } catch (err) {
        if (isStale()) return;
        dispatch({
          type: "ERROR",
          message: err instanceof Error ? err.message : "Upload failed",
        });
      }
    },
    [beginOperation]
  );

  const separate = useCallback(
    async (
      jobId: string,
      stems: string[],
      outputFormat: OutputFormat
    ) => {
      // A new run replaces any earlier subscription (e.g. a repeated click). Previously
      // the old teardown was overwritten in cleanupRef and never called.
      const isStale = beginOperation();
      try {
        dispatch({ type: "SEPARATE_START" });
        await startSeparation(jobId, stems, outputFormat);
        // Do not open a subscription for a job the user has already abandoned.
        if (isStale()) return;
        // Subscribe to progress
        cleanupRef.current = subscribeProgress(
          jobId,
          (event) => {
            if (isStale()) return;
            dispatch({
              type: "SEPARATE_PROGRESS",
              state: event.state,
              progress: event.progress,
              message: event.message,
            });
          },
          (stemResults) => {
            if (isStale()) return;
            dispatch({
              type: "SEPARATE_DONE",
              stems: resolveJobStemAssets(jobId, stemResults),
              downloadAllUrl: getDownloadAllUrl(jobId),
            });
          },
          (error) => {
            if (isStale()) return;
            dispatch({ type: "ERROR", message: error });
          }
        );
      } catch (err) {
        if (isStale()) return;
        dispatch({
          type: "ERROR",
          message:
            err instanceof Error ? err.message : "Separation failed",
        });
      }
    },
    [beginOperation]
  );

  const importFromUrl = useCallback(
    async (url: string) => {
      const isStale = beginOperation();
      try {
        dispatch({ type: "UPLOAD_START", message: "Downloading source track..." });
        const result = await importUrl(url);
        if (isStale()) return;
        dispatch({
          type: "UPLOAD_DONE",
          jobId: result.job_id,
          filename: result.filename,
          originalUrl: getAudioUrl(
            result.job_id,
            `input${getExtFromFilename(result.filename)}`
          ),
          outputFormat: getDefaultOutputFormat(result.filename),
          songName: result.title || stripExtension(result.filename),
          sourceKind: result.platform,
          sourceUrl: result.source_url,
          resolvedUrl: result.resolved_url,
        });
      } catch (err) {
        if (isStale()) return;
        dispatch({
          type: "ERROR",
          message: err instanceof Error ? err.message : "Import failed",
        });
      }
    },
    [beginOperation]
  );

  const loadExample = useCallback(async () => {
    const isStale = beginOperation();
    try {
      const example = await fetchExampleOutput();
      if (isStale()) return;
      dispatch({
        type: "LOAD_EXAMPLE_DONE",
        original: {
          ...example.original,
          songName: example.song || example.original.songName,
        },
        stems: example.stems,
        downloadAllUrl: example.downloadAllUrl,
      });
    } catch (err) {
      if (isStale()) return;
      dispatch({
        type: "ERROR",
        message:
          err instanceof Error ? err.message : "Failed to load example output",
      });
    }
  }, [beginOperation]);

  const setOutputFormat = useCallback((outputFormat: OutputFormat) => {
    dispatch({ type: "SET_OUTPUT_FORMAT", outputFormat });
  }, []);

  const reset = useCallback(() => {
    // Also invalidates any pending upload/import/example request.
    beginOperation();
    dispatch({ type: "RESET" });
  }, [beginOperation]);

  return {
    state,
    upload,
    importFromUrl,
    separate,
    loadExample,
    setOutputFormat,
    reset,
  };
}
/**
 * Returns the text from the last "." of `filename` to the end (e.g. ".mp3").
 *
 * - A name with no dot falls back to ".wav".
 * - A name whose only dot is the first character is returned whole.
 *
 * Callers use the result to build the `input<ext>` audio URL and to choose a default
 * output format.
 */
function getExtFromFilename(filename: string): string {
  const lastDot = filename.lastIndexOf(".");
  if (lastDot === -1) {
    return ".wav";
  }
  return filename.substring(lastDot);
}
/**
 * Returns `filename` without its final extension, for use as a display/song name.
 *
 * A dot in the first position (".mp3", ".hidden") does not start an extension, so
 * such names are returned unchanged. The previous `dot >= 0` check stripped them to
 * "", which produced an empty song name.
 */
function stripExtension(filename: string): string {
  const dot = filename.lastIndexOf(".");
  return dot > 0 ? filename.slice(0, dot) : filename;
}
/**
 * Chooses the initial output format from the input file's extension, ignoring case.
 * ".mp3" maps to "mp3" and ".aac" maps to "aac". Anything else, including ".wav" and
 * names without an extension, maps to "wav".
 */
function getDefaultOutputFormat(filename: string): OutputFormat {
  switch (getExtFromFilename(filename).toLowerCase()) {
    case ".mp3":
      return "mp3";
    case ".aac":
      return "aac";
    default:
      return "wav";
  }
}
/**
 * Turns the stem results reported for `jobId` into playable assets. Every field of
 * each result is kept, and its `audioUrl` and `downloadUrl` are derived from the
 * stem's filename. Output order matches input order.
 */
function resolveJobStemAssets(jobId: string, stems: StemResult[]): StemAsset[] {
  const assets: StemAsset[] = [];
  for (const stem of stems) {
    const { filename } = stem;
    assets.push({
      ...stem,
      audioUrl: getAudioUrl(jobId, filename),
      downloadUrl: getDownloadUrl(jobId, filename),
    });
  }
  return assets;
}
/** UI labels for imported sources; any kind not listed here is shown as a local upload. */
const SOURCE_LABELS: ReadonlyMap<string, string> = new Map([
  ["youtube", "YouTube import"],
  ["ytmusic", "YouTube Music import"],
  ["spotify", "Spotify import"],
]);

/**
 * Returns the human-readable label describing where a track came from.
 *
 * @param sourceKind - Origin of the track as recorded when it was uploaded/imported.
 * @returns The platform-specific import label, or "Uploaded file" for any other kind.
 */
export function getSourceLabel(sourceKind: SourceKind): string {
  return SOURCE_LABELS.get(sourceKind) ?? "Uploaded file";
}