|
|
import os |
|
|
import uuid |
|
|
import shutil |
|
|
import asyncio |
|
|
import aiohttp |
|
|
import librosa |
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
import json |
|
|
from fastapi import FastAPI, UploadFile, File, Form, Body |
|
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
from fastapi.templating import Jinja2Templates |
|
|
from fastapi.requests import Request |
|
|
from pydantic import BaseModel |
|
|
from typing import List, Optional |
|
|
|
|
|
# ASGI application object; the route decorators further down attach to it.
app = FastAPI()
|
|
|
|
|
|
|
|
# Base URLs of the remote worker spaces, in the order they are handed out.
# Entries are grouped by the hosting account that appears in the hostname.
WORKER_URLS_LIST = [
    # ezmary: workers 1-20
    "https://ezmary-sadaworker1.hf.space",
    "https://ezmary-sadaworker2.hf.space",
    "https://ezmary-sadaworker3.hf.space",
    "https://ezmary-sadaworker4.hf.space",
    "https://ezmary-sadaworker5.hf.space",
    "https://ezmary-sadaworker6.hf.space",
    "https://ezmary-sadaworker7.hf.space",
    "https://ezmary-sadaworker8.hf.space",
    "https://ezmary-sadaworker9.hf.space",
    "https://ezmary-sadaworker10.hf.space",
    "https://ezmary-sadaworker11.hf.space",
    "https://ezmary-sadaworker12.hf.space",
    "https://ezmary-sadaworker13.hf.space",
    "https://ezmary-sadaworker14.hf.space",
    "https://ezmary-sadaworker15.hf.space",
    "https://ezmary-sadaworker16.hf.space",
    "https://ezmary-sadaworker17.hf.space",
    "https://ezmary-sadaworker18.hf.space",
    "https://ezmary-sadaworker19.hf.space",
    "https://ezmary-sadaworker20.hf.space",
    # jdjddhdjdj: workers 21-30
    "https://jdjddhdjdj-sadakareqar21.hf.space",
    "https://jdjddhdjdj-sadakareqar22.hf.space",
    "https://jdjddhdjdj-sadakareqar23.hf.space",
    "https://jdjddhdjdj-sadakareqar24.hf.space",
    "https://jdjddhdjdj-sadakareqar25.hf.space",
    "https://jdjddhdjdj-sadakareqar26.hf.space",
    "https://jdjddhdjdj-sadakareqar27.hf.space",
    "https://jdjddhdjdj-sadakareqar28.hf.space",
    "https://jdjddhdjdj-sadakareqar29.hf.space",
    "https://jdjddhdjdj-sadakareqar30.hf.space",
    # ahmasiaq: workers 31-38
    "https://ahmasiaq-sadaworker31.hf.space",
    "https://ahmasiaq-sadaworker32.hf.space",
    "https://ahmasiaq-sadaworker33.hf.space",
    "https://ahmasiaq-sadaworker34.hf.space",
    "https://ahmasiaq-sadaworker35.hf.space",
    "https://ahmasiaq-sadaworker36.hf.space",
    "https://ahmasiaq-sadaworker37.hf.space",
    "https://ahmasiaq-sadaworker38.hf.space",
    # cartoe123: workers 39-50
    "https://cartoe123-sadakaregar39.hf.space",
    "https://cartoe123-sadakaregar40.hf.space",
    "https://cartoe123-sadakaregar41.hf.space",
    "https://cartoe123-sadakaregar42.hf.space",
    "https://cartoe123-sadakaregar43.hf.space",
    "https://cartoe123-sadakaregar44.hf.space",
    "https://cartoe123-sadakaregar45.hf.space",
    "https://cartoe123-sadakaregar46.hf.space",
    "https://cartoe123-sadakaregar47.hf.space",
    "https://cartoe123-sadakaregar48.hf.space",
    "https://cartoe123-sadakaregar49.hf.space",
    "https://cartoe123-sadakaregar50.hf.space",
    # sadajva: workers 51-60
    "https://sadajva-sadakaregar51.hf.space",
    "https://sadajva-sadakaregar52.hf.space",
    "https://sadajva-sadakaregar53.hf.space",
    "https://sadajva-sadakaregar54.hf.space",
    "https://sadajva-sadakaregar55.hf.space",
    "https://sadajva-sadakaregar56.hf.space",
    "https://sadajva-sadakaregar57.hf.space",
    "https://sadajva-sadakaregar58.hf.space",
    "https://sadajva-sadakaregar59.hf.space",
    "https://sadajva-sadakaregar60.hf.space",
]
|
|
|
|
|
# Make sure the scratch folder for in-flight jobs ("temp") and the folder that
# finished files are written to and served from ("results") exist at import time.
for _runtime_dir in ("temp", "results"):
    os.makedirs(_runtime_dir, exist_ok=True)

# Template loader rooted at the local "templates" directory.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
|
|
|
class AtomicWorkerManager:
    """Hand out worker URLs in round-robin order.

    The cursor is advanced while holding an ``asyncio.Lock`` so that each
    call to :meth:`get_next_worker` observes and updates it as one step.
    """

    def __init__(self, urls):
        """Remember ``urls`` (kept by reference) and start at the first entry."""
        self.urls = urls
        self.total_workers = len(urls)
        self.current_index = 0
        self.lock = asyncio.Lock()

    async def get_next_worker(self):
        """Return the URL under the cursor, then move the cursor forward.

        The cursor wraps back to index 0 after the last URL.
        """
        async with self.lock:
            position = self.current_index
            selected = self.urls[position]
            self.current_index = (position + 1) % self.total_workers
        return selected
|
|
|
|
|
# Single shared dispatcher: every upload draws its workers from this instance.
worker_manager = AtomicWorkerManager(urls=WORKER_URLS_LIST)
|
|
|
|
|
|
|
|
class ChunkInfo(BaseModel):
    """One chunk of a job as reported back by the client in ``/check_status``."""

    # Zero-based position of the chunk in the final audio.
    index: int

    # Base URL of the worker the chunk was sent to; the upload endpoint uses
    # the literal "skip" for chunks that were never submitted.
    worker_url: str

    # Task identifier returned by the worker, or one of the sentinel values
    # "skip" / "failed" assigned by the upload endpoint.
    task_id: str
|
|
|
|
|
class ProjectState(BaseModel):
    """Request body of ``/check_status``: the job description the client holds."""

    # Identifier generated by the upload endpoint for this job.
    job_id: str

    # Number of chunks the source audio was divided into.
    total_chunks: int

    # Per-chunk dispatch records (see ChunkInfo).
    chunks: List[ChunkInfo]
|
|
|
|
|
|
|
|
def find_split_points(audio_path, sr=24000, *, chunk_seconds=10.0,
                      min_chunk_seconds=8.0, lookback_seconds=1.0,
                      lookahead_seconds=2.0):
    """Load an audio file and choose cut positions at quiet moments.

    The signal is walked in steps of roughly ``chunk_seconds``. Around each
    nominal cut, a window from ``lookback_seconds`` before to
    ``lookahead_seconds`` after it (but never earlier than
    ``min_chunk_seconds`` into the chunk) is scanned, and the cut is placed
    on the RMS frame with the lowest energy.

    Args:
        audio_path: Path of the audio file to analyse.
        sr: Sample rate the audio is loaded at; all positions are in samples
            at this rate.
        chunk_seconds: Nominal chunk length.
        min_chunk_seconds: Earliest offset into a chunk where a cut may land.
        lookback_seconds: How far before the nominal cut to search.
        lookahead_seconds: How far after the nominal cut to search.

    Returns:
        ``(split_points, y)`` where ``split_points`` is an increasing list of
        sample offsets starting at 0 and ending at ``len(y)`` (just ``[0]``
        for empty audio), and ``y`` is the mono signal at ``sr``.

    Raises:
        ValueError: if ``chunk_seconds * sr`` is less than one sample.
    """
    chunk_len = int(chunk_seconds * sr)
    if chunk_len <= 0:
        # A zero-length step would never advance the loop below.
        raise ValueError("chunk_seconds * sr must be at least one sample")

    hop_length = 512
    frame_length = 1024

    try:
        y, _ = librosa.load(audio_path, sr=sr)
    except Exception as exc:
        # Fallback decoder. soundfile returns the file's native rate and
        # channel layout, so downmix and resample to keep the contract that
        # `y` is mono at `sr`.
        print(f"librosa.load failed for {audio_path} ({exc}); using soundfile")
        data, native_sr = sf.read(audio_path)
        if data.ndim > 1:
            data = np.mean(data, axis=1)
        y = np.asarray(data, dtype=np.float32)
        if native_sr != sr:
            y = librosa.resample(y, orig_sr=native_sr, target_sr=sr)

    total_samples = len(y)
    split_points = [0]
    current_pos = 0

    while current_pos < total_samples:
        target = current_pos + chunk_len
        if target >= total_samples:
            # Remaining audio fits in one chunk: close it and stop.
            split_points.append(total_samples)
            break

        search_start = max(
            current_pos + int(min_chunk_seconds * sr),
            target - int(lookback_seconds * sr),
        )
        search_end = min(total_samples, target + int(lookahead_seconds * sr))

        # Default to the nominal position; only move it when a usable quiet
        # frame is found strictly after the current position.
        cut_point = target
        if search_start < search_end:
            rms = librosa.feature.rms(
                y=y[search_start:search_end],
                frame_length=frame_length,
                hop_length=hop_length,
            )[0]
            quietest = search_start + int(np.argmin(rms)) * hop_length
            quietest = min(quietest, total_samples)
            if quietest > current_pos:
                cut_point = quietest

        split_points.append(cut_point)
        current_pos = cut_point

    return split_points, y
|
|
|
|
|
async def submit_to_worker(session, worker_url, chunk_path, ref_path):
    """Upload one chunk plus the reference audio to a worker's ``/submit``.

    Args:
        session: Open ``aiohttp.ClientSession`` used for the request.
        worker_url: Base URL of the worker (no trailing slash).
        chunk_path: Path of the chunk WAV file to send as ``source_file``.
        ref_path: Path of the reference WAV file to send as ``ref_file``.

    Returns:
        The ``task_id`` from the worker's JSON reply, or ``None`` when the
        worker answers with a non-200 status, the reply has no ``task_id``,
        or any error occurs. This function never raises.
    """
    try:
        # Both files stay open for the whole request because aiohttp streams
        # them while sending the multipart body.
        with open(chunk_path, "rb") as f_c, open(ref_path, "rb") as f_r:
            data = aiohttp.FormData()
            data.add_field("source_file", f_c, filename="c.wav", content_type="audio/wav")
            data.add_field("ref_file", f_r, filename="r.wav", content_type="audio/wav")

            # Passing a bare number as `timeout` is deprecated in aiohttp;
            # ClientTimeout(total=60) keeps the same 60 s overall limit.
            timeout = aiohttp.ClientTimeout(total=60)
            async with session.post(f"{worker_url}/submit", data=data, timeout=timeout) as resp:
                if resp.status != 200:
                    print(f"Submit to {worker_url} rejected: HTTP {resp.status}")
                    return None
                js = await resp.json()
                return js.get("task_id")
    except Exception as e:
        # Include the exception type: timeouts stringify to an empty message.
        print(f"Submit error to {worker_url}: {type(e).__name__}: {e}")
        return None
|
|
|
|
|
async def check_worker_status(session, worker_url, task_id):
    """Poll a worker for the result of one task.

    Args:
        session: Open ``aiohttp.ClientSession`` used for the request.
        worker_url: Base URL of the worker (no trailing slash).
        task_id: Identifier previously returned by the worker's ``/submit``.

    Returns:
        A ``(status, payload)`` tuple:

        * ``("completed", bytes)`` on HTTP 200; payload is the response body.
        * ``("processing", None)`` on HTTP 202 or 404, and on any transport
          error or timeout, so the caller simply polls again.
        * ``("failed", None)`` on every other HTTP status.
    """
    try:
        # Bound each poll so a stalled worker cannot hold the status request
        # open for the session's default timeout.
        timeout = aiohttp.ClientTimeout(total=60)
        async with session.get(f"{worker_url}/result/{task_id}", timeout=timeout) as resp:
            if resp.status == 200:
                return "completed", await resp.read()
            if resp.status in (202, 404):
                return "processing", None
            return "failed", None
    except Exception as e:
        # `except Exception` (not a bare except) lets task cancellation and
        # KeyboardInterrupt propagate instead of being reported as progress.
        print(f"Status check error for {worker_url}: {type(e).__name__}: {e}")
        return "processing", None
|
|
|
|
|
@app.get("/")
def home(request: Request):
    """Serve the landing page rendered from the ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
def _prepare_upload_job(job_dir, source_file, ref_file, sr):
    """Blocking part of ``/upload``: save uploads, normalise the reference, cut chunks.

    Args:
        job_dir: Existing directory that receives every file for this job.
        source_file: Readable binary file object with the source audio.
        ref_file: Readable binary file object with the reference audio.
        sr: Sample rate used for decoding and for every file written.

    Returns:
        ``(total_chunks, chunk_paths, clean_ref_path)`` where ``chunk_paths``
        maps chunk index to the WAV written for it. Segments shorter than
        half a second are not written and are absent from the mapping.

    Raises:
        ValueError: if the source audio yields no chunks at all.
    """
    src_path = f"{job_dir}/src.wav"
    ref_path = f"{job_dir}/ref.wav"
    clean_ref_path = f"{job_dir}/ref_clean.wav"

    with open(src_path, "wb") as out:
        shutil.copyfileobj(source_file, out)
    with open(ref_path, "wb") as out:
        shutil.copyfileobj(ref_file, out)

    try:
        # Re-encode the reference at `sr`.
        ref_y, _ = librosa.load(ref_path, sr=sr)
        sf.write(clean_ref_path, ref_y, sr)
    except Exception as exc:
        # Best effort: if the reference cannot be decoded, forward it as-is.
        print(f"Reference normalisation failed ({exc}); using the upload unchanged")
        shutil.copy(ref_path, clean_ref_path)

    split_points, y = find_split_points(src_path, sr)
    total_chunks = len(split_points) - 1
    if total_chunks <= 0:
        raise ValueError("source audio contains no samples")

    chunk_paths = {}
    for i in range(total_chunks):
        segment = y[split_points[i]:split_points[i + 1]]
        if len(segment) < 0.5 * sr:
            continue
        chunk_path = f"{job_dir}/chunk_{i}.wav"
        sf.write(chunk_path, segment, sr)
        chunk_paths[i] = chunk_path

    return total_chunks, chunk_paths, clean_ref_path


@app.post("/upload")
async def start_process(
    source_audio: UploadFile = File(...),
    ref_audio: UploadFile = File(...),
):
    """Split the uploaded source audio and submit every chunk to a worker.

    Returns a JSON object with ``job_id``, ``total_chunks``, ``chunks`` (one
    entry per chunk holding ``index``, ``worker_url`` and ``task_id``) and
    ``status: "started"``. ``task_id`` is the worker's id, ``"failed"`` when
    submission did not succeed, or ``"skip"`` for segments under 0.5 s that
    were never submitted.

    If the uploads cannot be prepared (for example undecodable or empty
    source audio), the job directory is removed and a 400 response with
    ``status: "error"`` is returned.
    """
    job_id = str(uuid.uuid4())
    job_dir = f"temp/{job_id}"
    os.makedirs(job_dir, exist_ok=True)
    sr = 24000

    # Decoding, resampling and writing WAV files is CPU/disk bound; run it on
    # the default thread pool so the event loop keeps serving other requests.
    loop = asyncio.get_running_loop()
    try:
        total_chunks, chunk_paths, clean_ref_path = await loop.run_in_executor(
            None, _prepare_upload_job, job_dir, source_audio.file, ref_audio.file, sr
        )
    except Exception as exc:
        print(f"Upload preparation failed for job {job_id}: {exc}")
        shutil.rmtree(job_dir, ignore_errors=True)
        return JSONResponse(
            status_code=400,
            content={"status": "error", "detail": str(exc)},
        )

    chunks_metadata = []
    pending = []  # (metadata entry, chunk path) for every chunk to submit
    for i in range(total_chunks):
        chunk_path = chunk_paths.get(i)
        if chunk_path is None:
            chunks_metadata.append({"index": i, "worker_url": "skip", "task_id": "skip"})
            continue

        worker_url = await worker_manager.get_next_worker()
        entry = {"index": i, "worker_url": worker_url, "task_id": "pending"}
        chunks_metadata.append(entry)
        pending.append((entry, chunk_path))

    async with aiohttp.ClientSession() as session:
        task_ids = await asyncio.gather(*(
            submit_to_worker(session, entry["worker_url"], chunk_path, clean_ref_path)
            for entry, chunk_path in pending
        ))

    # gather() preserves order, so results line up with `pending`.
    for (entry, _), task_id in zip(pending, task_ids):
        entry["task_id"] = task_id if task_id else "failed"

    return {
        "job_id": job_id,
        "total_chunks": total_chunks,
        "chunks": chunks_metadata,
        "status": "started",
    }
|
|
|
|
|
@app.post("/check_status")
async def check_status(project: ProjectState):
    """Poll every worker for a job and stitch the final file once it is ready.

    The client sends back the job description it received from ``/upload``.
    Each pending chunk is queried on its worker; when all chunks are done (or
    more than 95 % are done and at least one has failed) the parts are
    concatenated into ``results/final_<job_id>.wav`` and the job's temp
    directory is removed.

    Returns a JSON object with ``status`` (``"processing"``, ``"completed"``
    or ``"error"``) and ``progress`` (0-100), plus ``filename`` when
    completed or ``detail`` on error.
    """
    job_id = project.job_id

    # job_id comes from the client and is used to build file paths and as the
    # target of shutil.rmtree below, so accept only the canonical UUID form
    # that the upload endpoint generates.
    try:
        job_id_ok = str(uuid.UUID(job_id)) == job_id
    except (ValueError, AttributeError, TypeError):
        job_id_ok = False
    if not job_id_ok:
        return {"status": "error", "progress": 0, "detail": "Invalid job_id"}

    final_filename = f"final_{job_id}.wav"
    if os.path.exists(f"results/{final_filename}"):
        return {"status": "completed", "progress": 100, "filename": final_filename}

    if project.total_chunks <= 0:
        # Avoids a ZeroDivisionError in the progress computation.
        return {"status": "error", "progress": 0, "detail": "total_chunks must be positive"}

    completed_count = 0
    audio_parts = {}
    failed_any = False
    # worker_url is client-supplied too; only contact known workers so this
    # endpoint cannot be used to make the server fetch arbitrary URLs.
    allowed_workers = set(WORKER_URLS_LIST)

    async with aiohttp.ClientSession() as session:
        tasks = []
        task_indices = []

        for chunk in project.chunks:
            if chunk.task_id == "skip":
                # Skipped segments are replaced by 2400 samples of silence.
                completed_count += 1
                audio_parts[chunk.index] = np.zeros(2400)
                continue
            if chunk.task_id == "failed" or chunk.worker_url not in allowed_workers:
                failed_any = True
                continue

            tasks.append(check_worker_status(session, chunk.worker_url, chunk.task_id))
            task_indices.append(chunk.index)

        results = await asyncio.gather(*tasks)

    for idx, (status, data) in zip(task_indices, results):
        if status == "completed" and data:
            completed_count += 1
            audio_parts[idx] = data
        elif status == "failed":
            failed_any = True

    progress = int((completed_count / project.total_chunks) * 100)

    ready = completed_count == project.total_chunks or (
        completed_count > 0 and progress > 95 and failed_any
    )
    if not ready:
        return {"status": "processing", "progress": progress}

    try:
        full_audio = []
        sr = 24000
        for i in range(project.total_chunks):
            if i not in audio_parts:
                # Missing chunk: substitute two seconds of silence.
                full_audio.append(np.zeros(sr * 2))
                continue

            part = audio_parts[i]
            if not isinstance(part, bytes):
                full_audio.append(part)
                continue

            # Worker output arrives as an encoded file; decode it through a
            # temp file and always remove that file, even if decoding fails.
            tmp_path = f"temp/part_{job_id}_{i}.wav"
            try:
                with open(tmp_path, "wb") as f:
                    f.write(part)
                y, _ = librosa.load(tmp_path, sr=sr)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            full_audio.append(y)

        final_wav = np.concatenate(full_audio)
        sf.write(f"results/{final_filename}", final_wav, sr)

        shutil.rmtree(f"temp/{job_id}", ignore_errors=True)

        return {"status": "completed", "progress": 100, "filename": final_filename}

    except Exception as e:
        print(f"Stitch error: {e}")
        return {"status": "error", "progress": progress, "detail": str(e)}
|
|
|
|
|
@app.get("/download/{filename}")
def download_file(filename: str):
    """Send a finished result file from the ``results`` directory.

    Args:
        filename: Name of a file directly inside ``results``.

    Returns:
        A WAV ``FileResponse`` for the file, or ``{"error": "File not found"}``
        when the name contains a directory component or does not refer to a
        regular file.
    """
    path = f"results/{filename}"

    # Accept plain file names only: anything with a directory component
    # (e.g. "../x") is rejected, and isfile() rules out directories such as
    # "..", which os.path.exists() would let through.
    is_plain_name = os.path.basename(filename) == filename
    if not is_plain_name or not os.path.isfile(path):
        return {"error": "File not found"}

    return FileResponse(path, filename=filename, media_type="audio/wav")
|
|
|
|
|
if __name__ == "__main__":
    # The module-level imports do not include uvicorn, so import it here;
    # it is only needed when the file is run directly as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)