sada / app.py
Ezmary's picture
Update app.py
e6a8d02 verified
import os
import uuid
import shutil
import asyncio
import aiohttp
import librosa
import numpy as np
import soundfile as sf
import json
from fastapi import FastAPI, UploadFile, File, Form, Body
from fastapi.responses import FileResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.requests import Request
from pydantic import BaseModel
from typing import List, Optional
# FastAPI application instance; the route decorators further down register their handlers on it.
app = FastAPI()
# --- Worker list ---
# Each row: (owner part of the host name, name prefix, first number, last number).
# The numbering is continuous across rows, so the expanded list keeps the
# workers in ascending order from 1 to 60.
_WORKER_GROUPS = (
    ("ezmary", "sadaworker", 1, 20),
    ("jdjddhdjdj", "sadakareqar", 21, 30),
    ("ahmasiaq", "sadaworker", 31, 38),
    ("cartoe123", "sadakaregar", 39, 50),
    ("sadajva", "sadakaregar", 51, 60),
)
WORKER_URLS_LIST = [
    f"https://{owner}-{prefix}{number}.hf.space"
    for owner, prefix, first, last in _WORKER_GROUPS
    for number in range(first, last + 1)
]
# Working directories: "temp" receives per-job uploads and chunk files,
# "results" receives the stitched output files.
os.makedirs("temp", exist_ok=True)
os.makedirs("results", exist_ok=True)
# Template loader for the HTML page rendered by the "/" route.
templates = Jinja2Templates(directory="templates")
# --- Worker management (round-robin) ---
class AtomicWorkerManager:
    """Hand out worker URLs one at a time, cycling through a fixed list."""

    def __init__(self, urls):
        """Store *urls* and start the rotation at the first entry."""
        self.urls = urls
        self.total_workers = len(urls)
        self.current_index = 0
        # Guards the read-then-advance step in get_next_worker().
        self.lock = asyncio.Lock()

    async def get_next_worker(self):
        """Return the URL at the current position and move the cursor forward.

        The cursor wraps back to the start after the last URL.
        """
        async with self.lock:
            chosen = self.urls[self.current_index]
            self.current_index += 1
            self.current_index %= self.total_workers
            return chosen
# Module-wide dispatcher: the upload endpoint asks it for the next worker URL for every chunk.
worker_manager = AtomicWorkerManager(WORKER_URLS_LIST)
# --- Data models ---
class ChunkInfo(BaseModel):
    """Bookkeeping for one chunk of a job, as sent to /check_status."""

    # Position of the chunk within the job (0-based in the upload endpoint).
    index: int
    # Base URL of the worker the chunk was submitted to, or "skip".
    worker_url: str
    # Task id returned by the worker, or one of the markers "skip" / "failed".
    task_id: str


class ProjectState(BaseModel):
    """Request body of /check_status: the job id plus the state of every chunk."""

    job_id: str
    total_chunks: int
    chunks: List[ChunkInfo]
# --- Helper functions ---
def find_split_points(audio_path, sr=24000):
    """Load an audio file and choose chunk boundaries at quiet spots.

    Starting from the previous boundary, the next cut is placed on the
    quietest RMS frame found between about 9 and 12 seconds further on, so
    chunks are roughly 10 seconds long. Whatever is left once fewer than
    10 seconds remain becomes the final chunk.

    Args:
        audio_path: Path of the audio file to analyse.
        sr: Sample rate in Hz that the returned signal and all offsets use.

    Returns:
        A tuple ``(split_points, y)``. ``split_points`` is a list of sample
        offsets that starts with 0 and, for non-empty audio, ends with
        ``len(y)``; ``y`` is the loaded signal at ``sr``.
    """
    try:
        y, _ = librosa.load(audio_path, sr=sr)
    except Exception as exc:
        # Best-effort fallback: read the file directly with soundfile.
        print(f"librosa.load failed for {audio_path}: {exc}; falling back to soundfile")
        data, samplerate = sf.read(audio_path)
        if data.ndim > 1:
            # Average the channels into a single one.
            data = np.mean(data, axis=1)
        if samplerate != sr:
            # Every offset below, and the chunk files the caller writes, assume
            # `sr`; without resampling the audio would change speed and pitch.
            data = librosa.resample(data, orig_sr=samplerate, target_sr=sr)
        y = data

    hop = 512
    total_samples = len(y)
    split_points = [0]
    current_pos = 0
    while current_pos < total_samples:
        target = current_pos + int(10.0 * sr)
        if target >= total_samples:
            # Less than a full chunk remains: close the last chunk at the end.
            split_points.append(total_samples)
            break

        # Search window around the 10 s target: from 1 s before it (never
        # earlier than 8 s into the chunk) to 2 s after it.
        search_start = max(current_pos + int(8 * sr), target - int(1 * sr))
        search_end = min(total_samples, target + int(2 * sr))
        if search_start >= search_end:
            search_start = current_pos + int(8 * sr)
            search_end = min(total_samples, current_pos + int(12 * sr))

        region = y[search_start:search_end]
        if len(region) == 0:
            split_points.append(total_samples)
            break

        # Cut on the frame with the lowest energy inside the window.
        rms = librosa.feature.rms(y=region, frame_length=1024, hop_length=hop)[0]
        cut_point = search_start + int(np.argmin(rms)) * hop

        # If the cut would not move forward, force a fixed 10 s step.
        if cut_point <= current_pos:
            cut_point = current_pos + int(10 * sr)

        split_points.append(cut_point)
        current_pos = cut_point

    return split_points, y
async def submit_to_worker(session, worker_url, chunk_path, ref_path):
    """Upload one chunk plus the reference clip to a worker.

    Args:
        session: Open ``aiohttp.ClientSession`` used for the request.
        worker_url: Base URL of the worker; ``/submit`` is appended.
        chunk_path: Path of the chunk WAV file to send as ``source_file``.
        ref_path: Path of the reference WAV file to send as ``ref_file``.

    Returns:
        The ``task_id`` value from the worker's JSON reply, or ``None`` when
        the request fails, the worker answers with a non-200 status, or the
        reply carries no ``task_id``.
    """
    try:
        with open(chunk_path, 'rb') as f_c, open(ref_path, 'rb') as f_r:
            data = aiohttp.FormData()
            data.add_field('source_file', f_c, filename='c.wav', content_type='audio/wav')
            data.add_field('ref_file', f_r, filename='r.wav', content_type='audio/wav')
            # Generous limit so that large uploads can finish. ClientTimeout is
            # the documented type; a bare number is only kept for compatibility.
            timeout = aiohttp.ClientTimeout(total=60)
            async with session.post(f"{worker_url}/submit", data=data, timeout=timeout) as resp:
                if resp.status == 200:
                    js = await resp.json()
                    return js.get("task_id")
                # Make rejected submissions visible instead of dropping them silently.
                print(f"Submit to {worker_url} rejected: HTTP {resp.status}")
    except Exception as e:
        print(f"Submit error to {worker_url}: {e}")
    return None
async def check_worker_status(session, worker_url, task_id):
    """Ask a worker for the result of one task.

    Args:
        session: Open HTTP client session; its ``get()`` is used as an async
            context manager.
        worker_url: Base URL of the worker; ``/result/<task_id>`` is appended.
        task_id: Identifier the worker returned at submission time.

    Returns:
        A ``(state, payload)`` tuple:
        ``("completed", <response bytes>)`` on HTTP 200,
        ``("processing", None)`` on HTTP 202 or 404 and on request errors,
        ``("failed", None)`` on any other status.
    """
    try:
        async with session.get(f"{worker_url}/result/{task_id}") as resp:
            if resp.status == 200:
                return "completed", await resp.read()
            if resp.status in (202, 404):
                # 202 covers both 'queued' and 'processing'; 404 may mean the
                # task is not registered yet or has already been removed.
                return "processing", None
            return "failed", None
    except Exception:
        # A dropped connection is treated as "still working". Catching
        # Exception rather than everything lets task cancellation and
        # KeyboardInterrupt propagate instead of being reported as progress.
        return "processing", None
@app.get("/")
def home(request: Request):
    """Render the ``index.html`` template for the landing page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _prepare_reference(ref_path, clean_ref_path, sr):
    """Write the reference clip resampled to *sr*; fall back to a plain copy on failure."""
    try:
        ref_y, _ = librosa.load(ref_path, sr=sr)
        sf.write(clean_ref_path, ref_y, sr)
    except Exception as exc:
        # If the clip cannot be decoded or rewritten, use the upload as it is.
        print(f"Reference preparation failed, using the uploaded file: {exc}")
        shutil.copy(ref_path, clean_ref_path)


@app.post("/upload")
async def start_process(
    source_audio: UploadFile = File(...),
    ref_audio: UploadFile = File(...)
):
    """Store the uploads, cut the source into chunks and submit them to workers.

    Args:
        source_audio: Audio file to process; saved as ``temp/<job_id>/src.wav``.
        ref_audio: Reference audio; saved as ``temp/<job_id>/ref.wav``.

    Returns:
        A dict with ``job_id``, ``total_chunks``, ``chunks`` (one entry per
        chunk holding ``index``, ``worker_url`` and ``task_id``) and
        ``status`` set to ``"started"``. A chunk's ``task_id`` is the id
        returned by the worker, ``"failed"`` when submission did not yield
        one, or ``"skip"`` for chunks shorter than half a second.
    """
    job_id = str(uuid.uuid4())
    job_dir = f"temp/{job_id}"
    os.makedirs(job_dir, exist_ok=True)
    src_path = f"{job_dir}/src.wav"
    ref_path = f"{job_dir}/ref.wav"
    clean_ref_path = f"{job_dir}/ref_clean.wav"
    with open(src_path, "wb") as out:
        shutil.copyfileobj(source_audio.file, out)
    with open(ref_path, "wb") as out:
        shutil.copyfileobj(ref_audio.file, out)

    sr = 24000
    # Decoding and analysing audio is CPU-bound; run it in the default
    # executor so the event loop keeps serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _prepare_reference, ref_path, clean_ref_path, sr)
    split_points, y = await loop.run_in_executor(None, find_split_points, src_path, sr)
    total_chunks = len(split_points) - 1

    chunks_metadata = []
    submitted_entries = []  # metadata of chunks actually sent, in submission order
    submissions = []
    async with aiohttp.ClientSession() as session:
        for i, (start, end) in enumerate(zip(split_points, split_points[1:])):
            chunk_audio = y[start:end]
            # Skip very short pieces (under 0.5 seconds).
            if len(chunk_audio) < 0.5 * sr:
                chunks_metadata.append({"index": i, "worker_url": "skip", "task_id": "skip"})
                continue
            chunk_path = f"{job_dir}/chunk_{i}.wav"
            sf.write(chunk_path, chunk_audio, sr)
            worker_url = await worker_manager.get_next_worker()
            entry = {"index": i, "worker_url": worker_url, "task_id": "pending"}
            chunks_metadata.append(entry)
            submitted_entries.append(entry)
            submissions.append(submit_to_worker(session, worker_url, chunk_path, clean_ref_path))
        # Submit all chunks concurrently.
        results = await asyncio.gather(*submissions)

    # gather() keeps submission order, so results pair up with their entries.
    for entry, task_id in zip(submitted_entries, results):
        entry["task_id"] = task_id or "failed"

    # The job directory is left in place here; /check_status removes it after
    # the final file has been written.
    return {
        "job_id": job_id,
        "total_chunks": total_chunks,
        "chunks": chunks_metadata,
        "status": "started"
    }
@app.post("/check_status")
async def check_status(project: ProjectState):
    """Poll the workers for a job and stitch the final file once it is done.

    Args:
        project: Job state as returned by /upload (job id, chunk count and the
            per-chunk worker URL / task id). It is supplied by the client, so
            the job id and worker URLs are validated before use.

    Returns:
        ``{"status": "completed", "progress": 100, "filename": ...}`` when the
        final file exists or has just been written,
        ``{"status": "processing", "progress": <0-100>}`` while chunks are
        outstanding, or
        ``{"status": "error", "progress": ..., "detail": ...}`` for invalid
        input or a stitching failure.
    """
    # job_id ends up in file paths and in the rmtree target below, so only the
    # canonical UUID text that /upload generates is accepted.
    try:
        job_id_is_canonical = str(uuid.UUID(project.job_id)) == project.job_id
    except ValueError:
        job_id_is_canonical = False
    if not job_id_is_canonical:
        return {"status": "error", "progress": 0, "detail": "invalid job_id"}

    final_filename = f"final_{project.job_id}.wav"
    final_path = f"results/{final_filename}"
    if os.path.exists(final_path):
        return {"status": "completed", "progress": 100, "filename": final_filename}

    # Avoids a division by zero in the progress computation.
    if project.total_chunks <= 0:
        return {"status": "error", "progress": 0, "detail": "job has no chunks"}

    completed_count = 0
    audio_parts = {}
    failed_any = False
    known_workers = set(WORKER_URLS_LIST)

    async with aiohttp.ClientSession() as session:
        polls = []
        poll_indices = []
        for chunk in project.chunks:
            if chunk.task_id == "skip":
                completed_count += 1
                audio_parts[chunk.index] = np.zeros(2400)  # short silence
                continue
            # Never fetch a URL that is not one of our own workers.
            if chunk.task_id == "failed" or chunk.worker_url not in known_workers:
                failed_any = True
                continue
            polls.append(check_worker_status(session, chunk.worker_url, chunk.task_id))
            poll_indices.append(chunk.index)
        results = await asyncio.gather(*polls)

    for idx, (status, data) in zip(poll_indices, results):
        if status == "completed" and data:
            completed_count += 1
            # Keep the raw bytes; they are decoded during stitching.
            audio_parts[idx] = data
        elif status == "failed":
            failed_any = True
        # Anything else is still processing: just wait for the next poll.

    progress = int((completed_count / project.total_chunks) * 100)
    all_done = completed_count == project.total_chunks
    # Also finish when nearly everything completed and the rest has failed.
    mostly_done = completed_count > 0 and progress > 95 and failed_any
    if not (all_done or mostly_done):
        return {"status": "processing", "progress": progress}

    try:
        sr = 24000
        full_audio = []
        for i in range(project.total_chunks):
            part = audio_parts.get(i)
            if part is None:
                # Missing or broken chunk: insert two seconds of silence.
                full_audio.append(np.zeros(sr * 2))
            elif isinstance(part, bytes):
                tmp_path = f"temp/part_{project.job_id}_{i}.wav"
                try:
                    with open(tmp_path, "wb") as f:
                        f.write(part)
                    y, _ = librosa.load(tmp_path, sr=sr)
                finally:
                    # Remove the scratch file even when decoding fails.
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                full_audio.append(y)
            else:
                full_audio.append(part)  # silence placeholder of a skipped chunk

        final_wav = np.concatenate(full_audio)
        sf.write(final_path, final_wav, sr)

        # Final cleanup of the job's working directory.
        shutil.rmtree(f"temp/{project.job_id}", ignore_errors=True)
        return {"status": "completed", "progress": 100, "filename": final_filename}
    except Exception as e:
        print(f"Stitch error: {e}")
        return {"status": "error", "progress": progress, "detail": str(e)}
@app.get("/download/{filename}")
def download_file(filename: str):
    """Send a finished result file from the ``results`` directory.

    Args:
        filename: Bare file name inside ``results`` (no directory parts).

    Returns:
        A ``FileResponse`` with media type ``audio/wav`` when the file exists;
        otherwise a 404 JSON response whose body is
        ``{"error": "File not found"}``.
    """
    path = f"results/{filename}"
    # Accept only a plain file name that refers to a regular file, so the
    # request cannot address a directory or step outside results/.
    is_bare_name = os.path.basename(filename) == filename
    if is_bare_name and os.path.isfile(path):
        return FileResponse(path, filename=filename, media_type="audio/wav")
    return JSONResponse(status_code=404, content={"error": "File not found"})
if __name__ == "__main__":
    # Imported here because this script entry point is the only place that
    # uses uvicorn; without the import the call below raises NameError.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)