# stem-separator / backend/source_import.py
# (Uploaded via huggingface_hub, revision 20e0b97, verified)
import html
import json
import logging
import os
import re
import shutil
import socket
import ssl
import tempfile
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from urllib.parse import parse_qs, quote, urlparse
from urllib.request import Request, urlopen

from yt_dlp import YoutubeDL

from backend import file_manager
logger = logging.getLogger(__name__)
# Browser-style User-Agent for direct HTTP fetches (Spotify oEmbed / page HTML).
USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/122.0.0.0 Safari/537.36"
)
# Number of download/search attempts before giving up.
IMPORT_RETRY_COUNT = int(os.getenv("IMPORT_RETRY_COUNT", "3"))
# Base delay for exponential backoff between attempts, in milliseconds.
IMPORT_RETRY_BASE_DELAY_MS = int(os.getenv("IMPORT_RETRY_BASE_DELAY_MS", "800"))
# Pin yt-dlp and the diagnostics probes to IPv4 unless explicitly disabled.
IMPORT_FORCE_IPV4 = os.getenv("IMPORT_FORCE_IPV4", "1").lower() not in {"0", "false", "no"}
# Emit structured failure logs unless explicitly disabled.
IMPORT_DIAGNOSTICS_ENABLED = os.getenv("IMPORT_DIAGNOSTICS_ENABLED", "1").lower() not in {"0", "false", "no"}
# Hosts probed by run_network_diagnostics().
DEFAULT_DIAG_HOSTS = [
    "www.youtube.com",
    "youtube.com",
    "music.youtube.com",
    "youtu.be",
    "google.com",
]
@dataclass
class ImportedTrack:
    """Result of a successful import: where the audio landed and what it is."""

    # Identifier of the file_manager job that owns the downloaded audio.
    job_id: str
    # Display file name ("<title>.wav") saved under the job.
    filename: str
    # URL as provided by the user (after whitespace normalization).
    source_url: str
    # Concrete YouTube URL the audio was fetched from (Spotify imports);
    # same as source_url for direct YouTube / YouTube Music links.
    resolved_url: str | None
    # Human-readable track title after sanitization.
    title: str
    # One of "youtube", "ytmusic", or "spotify".
    platform: str
@dataclass
class ImportFailureContext:
    """Structured context attached to a SourceImportError for diagnostics/logging."""

    # Pipeline stage where the failure occurred, e.g. "validation",
    # "search", "download", "dns", "metadata", "postprocess".
    stage: str
    # Hostname involved in the failing operation, when known.
    host: str | None = None
    # Whether retrying the same operation may succeed.
    retryable: bool = False
    # Preflight DNS/TLS probe result from run_host_diagnostic, when one was run.
    diagnostics: dict | None = None
    # Per-attempt outcome records accumulated during retries.
    attempts: list[dict] = field(default_factory=list)
class SourceImportError(Exception):
    """Raised for any failure in the source-import pipeline.

    The message is written to be user-presentable; machine-readable
    details live in ``self.context`` (an ImportFailureContext).
    """

    def __init__(
        self,
        message: str,
        *,
        stage: str = "download",
        host: str | None = None,
        retryable: bool = False,
        diagnostics: dict | None = None,
        attempts: list[dict] | None = None,
    ):
        super().__init__(message)
        # Structured failure context; may be enriched later by
        # attach_and_log_diagnostics before the error is surfaced.
        self.context = ImportFailureContext(
            stage=stage,
            host=host,
            retryable=retryable,
            diagnostics=diagnostics,
            attempts=attempts or [],
        )
def import_source(url: str) -> ImportedTrack:
    """Import the audio behind *url* into a new job and describe the result.

    YouTube / YouTube Music links are downloaded directly. Spotify track
    links are resolved to a YouTube match via search, then downloaded.
    On any failure the partially created job is deleted before the
    exception propagates.
    """
    source_url = normalize_url(url)
    platform = classify_platform(source_url)
    job_id = file_manager.create_job()
    job_dir = file_manager.get_job_dir(job_id)
    try:
        if platform in {"youtube", "ytmusic"}:
            title, filename, diag = download_youtube_source(source_url, job_dir)
            file_manager.save_job_metadata(
                job_id,
                {
                    "source_kind": platform,
                    "source_url": source_url,
                    "resolved_url": source_url,
                    "title": title,
                    "filename": filename,
                    "import_diagnostics": diag,
                },
            )
            return ImportedTrack(
                job_id=job_id,
                filename=filename,
                source_url=source_url,
                resolved_url=source_url,
                title=title,
                platform=platform,
            )
        # Spotify: read public metadata, find a YouTube match, download that.
        track_title, primary_artist = fetch_spotify_track_metadata(source_url)
        resolved_url, search_diag = resolve_youtube_search(
            f"{track_title} {primary_artist} audio"
        )
        title, filename, download_diag = download_youtube_source(
            resolved_url,
            job_dir,
            title_hint=f"{track_title} - {primary_artist}",
        )
        file_manager.save_job_metadata(
            job_id,
            {
                "source_kind": "spotify",
                "source_url": source_url,
                "resolved_url": resolved_url,
                "title": title,
                "filename": filename,
                "spotify_track_title": track_title,
                "spotify_primary_artist": primary_artist,
                "import_diagnostics": {
                    "search": search_diag,
                    "download": download_diag,
                },
            },
        )
        return ImportedTrack(
            job_id=job_id,
            filename=filename,
            source_url=source_url,
            resolved_url=resolved_url,
            title=title,
            platform="spotify",
        )
    except Exception:
        # Never leave a half-populated job behind.
        file_manager.delete_job(job_id)
        raise
def normalize_url(url: str) -> str:
    """Trim surrounding whitespace; reject empty input with a validation error."""
    trimmed = url.strip()
    if trimmed:
        return trimmed
    raise SourceImportError(
        "Paste a YouTube, YouTube Music, or Spotify track link",
        stage="validation",
    )
def classify_platform(url: str) -> str:
    """Classify a normalized URL as "youtube", "ytmusic", or "spotify".

    Only single-track links are accepted; playlist/channel/album style
    links raise SourceImportError with stage "validation", as do hosts
    this importer does not support.

    Raises:
        SourceImportError: for playlists, non-track pages, and
            unsupported hosts/paths.
    """
    parsed = urlparse(url)
    host = parsed.netloc.lower().removeprefix("www.")
    path = parsed.path
    query = parse_qs(parsed.query)
    if host in {"youtube.com", "m.youtube.com", "youtu.be"}:
        if "list" in query:
            raise SourceImportError("Playlist links are not supported yet", stage="validation")
        if host == "youtu.be":
            # Short links always address a single video.
            return "youtube"
        if path.startswith("/watch") or path.startswith("/shorts/"):
            return "youtube"
        if path.startswith("/playlist") or path.startswith("/channel/") or path.startswith("/@"):
            raise SourceImportError("Only single YouTube video links are supported", stage="validation")
        raise SourceImportError("Unsupported YouTube link", stage="validation")
    if host == "music.youtube.com":
        if "list" in query:
            raise SourceImportError("Playlist links are not supported yet", stage="validation")
        if path.startswith("/watch"):
            return "ytmusic"
        raise SourceImportError("Only single YouTube Music track links are supported", stage="validation")
    if host == "open.spotify.com":
        parts = [part for part in path.split("/") if part]
        # Spotify localizes share links as /intl-<lang>/track/<id>;
        # drop the locale segment so such links classify correctly.
        if parts and parts[0].startswith("intl-"):
            parts = parts[1:]
        if len(parts) >= 2 and parts[0] == "track":
            return "spotify"
        if parts and parts[0] in {"album", "playlist", "artist", "show", "episode"}:
            raise SourceImportError("Only single Spotify track links are supported", stage="validation")
        raise SourceImportError("Unsupported Spotify link", stage="validation")
    raise SourceImportError(
        "Unsupported link. Use YouTube, YouTube Music, or Spotify track URLs",
        stage="validation",
    )
def download_youtube_source(
    source_url: str,
    job_dir: Path,
    title_hint: str | None = None,
) -> tuple[str, str, dict]:
    """Download the audio for *source_url* into the job, retrying transient failures.

    Returns ``(display_title, display_filename, diagnostics)`` where
    diagnostics records the preflight host probe and per-attempt outcomes.
    *title_hint*, when given, overrides the title reported by the extractor.

    Raises:
        SourceImportError: after the retry budget is exhausted or on the
            first non-retryable failure.
    """
    host = extract_host(source_url)
    attempts: list[dict] = []
    # Probe DNS/TLS once up front so any failure carries connectivity context.
    preflight = run_host_diagnostic(host) if host else None
    retryable_error: SourceImportError | None = None
    for attempt_index in range(1, IMPORT_RETRY_COUNT + 1):
        # Each attempt downloads into its own scratch dir inside the job dir.
        temp_dir = Path(tempfile.mkdtemp(prefix="import-", dir=str(job_dir)))
        attempt_started = time.perf_counter()
        try:
            options = build_ytdlp_options(temp_dir)
            with YoutubeDL(options) as ydl:
                info = ydl.extract_info(source_url, download=True)
            output_path = find_downloaded_audio(temp_dir)
            content = output_path.read_bytes()
            display_title = sanitize_title(title_hint or info.get("title") or output_path.stem)
            display_filename = f"{display_title}.wav"
            # Persist the audio before the scratch dir is removed in `finally`.
            file_manager.save_imported_audio(job_dir.name, display_filename, content)
            attempts.append(
                {
                    "attempt": attempt_index,
                    "status": "success",
                    "duration_seconds": time.perf_counter() - attempt_started,
                }
            )
            return display_title, display_filename, {
                "stage": "download",
                "preflight": preflight,
                "attempts": attempts,
            }
        except SourceImportError as exc:
            attempts.append(build_attempt_record(attempt_index, attempt_started, exc))
            retryable_error = exc if exc.context.retryable else None
            # Stop on non-retryable errors or once the retry budget is spent.
            if not exc.context.retryable or attempt_index == IMPORT_RETRY_COUNT:
                attach_and_log_diagnostics(exc, source_url, host, preflight, attempts)
                raise exc
            backoff_sleep(attempt_index)
        except Exception as exc:
            # Normalize third-party failures (yt-dlp, network) to SourceImportError.
            wrapped = wrap_external_error(exc, stage="download", host=host)
            attempts.append(build_attempt_record(attempt_index, attempt_started, wrapped))
            retryable_error = wrapped if wrapped.context.retryable else None
            if not wrapped.context.retryable or attempt_index == IMPORT_RETRY_COUNT:
                attach_and_log_diagnostics(wrapped, source_url, host, preflight, attempts)
                raise wrapped
            backoff_sleep(attempt_index)
        finally:
            # The scratch dir is always removed; the audio was copied out on success.
            shutil.rmtree(temp_dir, ignore_errors=True)
    # NOTE(review): defensive fallbacks — with IMPORT_RETRY_COUNT >= 1 the loop
    # above always returns or raises, so this is reached only when the retry
    # count is configured to 0 (in which case retryable_error is still None).
    if retryable_error is not None:
        attach_and_log_diagnostics(retryable_error, source_url, host, preflight, attempts)
        raise retryable_error
    raise SourceImportError(
        "Temporary YouTube connectivity issue, please retry",
        stage="download",
        host=host,
        retryable=True,
        diagnostics=preflight,
        attempts=attempts,
    )
def resolve_youtube_search(query: str) -> tuple[str, dict]:
    """Resolve *query* to a single YouTube watch URL via yt-dlp's ytsearch1.

    Returns ``(resolved_url, diagnostics)`` where diagnostics records the
    preflight host probe and per-attempt outcomes.

    Raises:
        SourceImportError: on connectivity failure after retries, when the
            search returns no entries, or when the top entry has no URL.
    """
    search_url = "https://www.youtube.com/results"
    host = extract_host(search_url)
    # One preflight DNS/TLS probe shared by every attempt's diagnostics.
    preflight = run_host_diagnostic(host) if host else None
    attempts: list[dict] = []
    for attempt_index in range(1, IMPORT_RETRY_COUNT + 1):
        attempt_started = time.perf_counter()
        try:
            options = {
                "quiet": True,
                "no_warnings": True,
                # Flat extraction: fetch entry metadata only, no media download.
                "extract_flat": "in_playlist",
                "noplaylist": True,
            }
            if IMPORT_FORCE_IPV4:
                options["source_address"] = "0.0.0.0"
            with YoutubeDL(options) as ydl:
                info = ydl.extract_info(f"ytsearch1:{query}", download=False)
        except Exception as exc:
            wrapped = wrap_external_error(exc, stage="search", host=host)
            attempts.append(build_attempt_record(attempt_index, attempt_started, wrapped))
            if not wrapped.context.retryable or attempt_index == IMPORT_RETRY_COUNT:
                attach_and_log_diagnostics(wrapped, query, host, preflight, attempts)
                raise wrapped
            backoff_sleep(attempt_index)
            continue
        entries = info.get("entries") or []
        if not entries:
            # An empty result set is final; retrying will not change it.
            error = SourceImportError(
                "No matching YouTube source was found for this Spotify track",
                stage="search",
                host=host,
                retryable=False,
            )
            attempts.append(build_attempt_record(attempt_index, attempt_started, error))
            attach_and_log_diagnostics(error, query, host, preflight, attempts)
            raise error
        entry = entries[0]
        resolved_url = entry.get("webpage_url") or entry.get("url")
        # Flat-extracted entries may carry a bare video id; rebuild a watch URL.
        if resolved_url and not str(resolved_url).startswith("http"):
            video_id = entry.get("id") or resolved_url
            resolved_url = f"https://www.youtube.com/watch?v={video_id}"
        if not resolved_url:
            error = SourceImportError(
                "Resolved YouTube match did not include a downloadable URL",
                stage="search",
                host=host,
                retryable=False,
            )
            attempts.append(build_attempt_record(attempt_index, attempt_started, error))
            attach_and_log_diagnostics(error, query, host, preflight, attempts)
            raise error
        attempts.append(
            {
                "attempt": attempt_index,
                "status": "success",
                "duration_seconds": time.perf_counter() - attempt_started,
            }
        )
        return resolved_url, {"stage": "search", "preflight": preflight, "attempts": attempts}
    # Reachable only when IMPORT_RETRY_COUNT is configured to 0.
    raise SourceImportError(
        "Temporary YouTube connectivity issue, please retry",
        stage="search",
        host=host,
        retryable=True,
        diagnostics=preflight,
        attempts=attempts,
    )
def fetch_spotify_track_metadata(url: str) -> tuple[str, str]:
    """Resolve ``(track_title, primary_artist)`` for a Spotify track URL.

    Tries Spotify's public oEmbed endpoint first; on any failure or an
    incomplete payload it falls back to scraping the track page's meta
    tags.

    Raises:
        SourceImportError: stage "metadata" when neither source yields
            both a title and an artist.
    """
    oembed_endpoint = f"https://open.spotify.com/oembed?url={quote(url, safe='')}"
    oembed_request = Request(oembed_endpoint, headers={"User-Agent": USER_AGENT})
    try:
        with urlopen(oembed_request, timeout=20) as response:
            payload = json.loads(response.read().decode("utf-8"))
        oembed_title = payload.get("title", "").strip()
        oembed_artist = payload.get("author_name", "").strip()
        if oembed_title and oembed_artist:
            return oembed_title, oembed_artist
    except Exception:
        # Deliberate best-effort: any oEmbed failure falls through to scraping.
        pass
    page_html = fetch_text(url)
    page_title = first_match(
        page_html,
        [
            r'<meta property="og:title" content="([^"]+)"',
            r"<title>([^<]+)</title>",
        ],
    )
    page_artist = first_match(
        page_html,
        [
            r'<meta name="music:musician_description" content="([^"]+)"',
            r'"artists"\s*:\s*\[\s*\{\s*"name"\s*:\s*"([^"]+)"',
            r'"byArtist"\s*:\s*\{\s*"name"\s*:\s*"([^"]+)"',
        ],
    )
    if not page_title or not page_artist:
        raise SourceImportError(
            "Could not read Spotify track metadata from the public page",
            stage="metadata",
        )
    return clean_spotify_title(page_title), page_artist.strip()
def fetch_text(url: str) -> str:
    """GET *url* with a browser User-Agent and return the body decoded
    as UTF-8 (undecodable bytes are dropped)."""
    page_request = Request(url, headers={"User-Agent": USER_AGENT})
    with urlopen(page_request, timeout=20) as response:
        body = response.read()
    return body.decode("utf-8", errors="ignore")
def find_downloaded_audio(temp_dir: Path) -> Path:
    """Return the most recently modified audio file in *temp_dir*.

    Raises:
        SourceImportError: stage "postprocess" when no file with a known
            audio extension is present.
    """
    audio_extensions = {".wav", ".mp3", ".m4a", ".aac", ".flac", ".opus", ".ogg"}
    candidates = [
        entry
        for entry in temp_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in audio_extensions
    ]
    if not candidates:
        raise SourceImportError(
            "Downloaded source did not produce a playable audio file",
            stage="postprocess",
        )
    # Newest file wins — it is the one the current download just produced.
    return max(candidates, key=lambda entry: entry.stat().st_mtime)
def build_ytdlp_options(temp_dir: Path) -> dict:
    """Assemble yt-dlp options that download best audio into *temp_dir*
    and post-process it to WAV via ffmpeg."""
    ytdlp_options: dict = {
        "format": "bestaudio/best",
        "paths": {"home": str(temp_dir)},
        "outtmpl": {"default": "downloaded.%(ext)s"},
        "quiet": True,
        "no_warnings": True,
        "noplaylist": True,
        "extract_flat": False,
        "retries": 1,
        "fragment_retries": 1,
        "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav"}],
    }
    if IMPORT_FORCE_IPV4:
        # Binding to 0.0.0.0 forces yt-dlp onto IPv4 sockets.
        ytdlp_options["source_address"] = "0.0.0.0"
    return ytdlp_options
def run_network_diagnostics() -> list[dict]:
    """Probe each well-known host and collect the per-host DNS/TLS results."""
    return list(map(run_host_diagnostic, DEFAULT_DIAG_HOSTS))
def run_host_diagnostic(host: str, port: int = 443) -> dict:
    """Probe DNS resolution and a TLS handshake against *host*:*port*.

    Returns a dict with ``dns_ok``/``https_ok`` flags, timings, the
    resolved IP, and error strings on failure. Never raises; all
    exceptions are captured into the result.
    """
    result: dict = {"host": host, "port": port}
    dns_started = time.perf_counter()
    try:
        info = socket.getaddrinfo(
            host,
            port,
            # Mirror the downloader's IPv4 pinning so the probe is representative.
            family=socket.AF_INET if IMPORT_FORCE_IPV4 else socket.AF_UNSPEC,
            type=socket.SOCK_STREAM,
        )
        ip = info[0][4][0]
        result["dns_ok"] = True
        result["resolved_ip"] = ip
    except Exception as exc:
        # DNS failed; no point attempting the TCP/TLS stage.
        result["dns_ok"] = False
        result["dns_error"] = str(exc)
        result["dns_seconds"] = time.perf_counter() - dns_started
        return result
    result["dns_seconds"] = time.perf_counter() - dns_started
    connect_started = time.perf_counter()
    try:
        family, socktype, proto, _, sockaddr = info[0]
        with socket.socket(family, socktype, proto) as raw_sock:
            raw_sock.settimeout(5)
            raw_sock.connect(sockaddr)
            # Completing the TLS handshake (with cert verification) proves
            # HTTPS reachability; no request is actually sent.
            context = ssl.create_default_context()
            with context.wrap_socket(raw_sock, server_hostname=host):
                pass
        result["https_ok"] = True
    except Exception as exc:
        result["https_ok"] = False
        result["https_error"] = str(exc)
    result["https_seconds"] = time.perf_counter() - connect_started
    return result
def wrap_external_error(exc: Exception, *, stage: str, host: str | None) -> SourceImportError:
    """Translate an arbitrary downloader/network exception into a
    SourceImportError, classifying it as DNS, transient, or permanent
    by pattern-matching the exception text."""
    detail = str(exc).lower()
    dns_markers = (
        "failed to resolve",
        "temporary failure in name resolution",
        "no address associated with hostname",
    )
    if any(marker in detail for marker in dns_markers):
        return SourceImportError(
            "Could not resolve YouTube host from the Space runtime",
            stage="dns",
            host=host,
            retryable=True,
        )
    transient_markers = (
        "timed out",
        "timeout",
        "connection reset",
        "network is unreachable",
        "transporterror",
        "ssl",
    )
    if any(marker in detail for marker in transient_markers):
        return SourceImportError(
            "Temporary YouTube connectivity issue, please retry",
            stage=stage,
            host=host,
            retryable=True,
        )
    # Anything unrecognized is treated as permanent for this stage.
    return SourceImportError(
        f"YouTube {stage} request failed after retries",
        stage=stage,
        host=host,
        retryable=False,
    )
def attach_and_log_diagnostics(
    error: SourceImportError,
    target: str,
    host: str | None,
    preflight: dict | None,
    attempts: list[dict],
) -> None:
    """Fold the final host/preflight/attempt data into *error*'s context
    and, when diagnostics are enabled, log a structured failure record."""
    ctx = error.context
    if host:
        ctx.host = host
    ctx.diagnostics = preflight
    ctx.attempts = attempts
    if not IMPORT_DIAGNOSTICS_ENABLED:
        return
    failure_record = {
        "target": target,
        "message": str(error),
        "context": asdict(ctx),
    }
    logger.warning("source_import_failure %s", json.dumps(failure_record))
def build_attempt_record(
attempt_index: int,
attempt_started: float,
error: SourceImportError,
) -> dict:
return {
"attempt": attempt_index,
"status": "error",
"duration_seconds": time.perf_counter() - attempt_started,
"stage": error.context.stage,
"retryable": error.context.retryable,
"message": str(error),
}
def backoff_sleep(attempt_index: int) -> None:
    """Sleep with exponential backoff: the base delay doubles per prior attempt."""
    base_seconds = IMPORT_RETRY_BASE_DELAY_MS / 1000.0
    time.sleep(base_seconds * (2 ** (attempt_index - 1)))
def extract_host(url: str) -> str | None:
parsed = urlparse(url)
return parsed.netloc or None
def first_match(text: str, patterns: list[str]) -> str | None:
for pattern in patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
return unescape_html(match.group(1))
return None
def sanitize_title(value: str) -> str:
    """Turn an arbitrary title into a safe display/file name.

    Characters illegal in Windows file names become spaces, whitespace is
    collapsed, surrounding whitespace and dots are trimmed, and the result
    is capped at 120 characters. Falls back to "Imported Track" when
    nothing printable remains.
    """
    without_reserved = re.sub(r"[\\/:*?\"<>|]+", " ", value)
    collapsed = " ".join(without_reserved.split())
    trimmed = collapsed.strip(".")
    return trimmed[:120] or "Imported Track"
def clean_spotify_title(value: str) -> str:
    """Strip Spotify page-title decorations: the " | Spotify" suffix and
    the " - song and lyrics by <artist>" tail."""
    cleaned = value.replace(" | Spotify", "").strip()
    if " - song and lyrics by " in cleaned.lower():
        cleaned = re.split(
            r"\s+-\s+song and lyrics by\s+",
            cleaned,
            flags=re.IGNORECASE,
        )[0]
    return cleaned.strip()
def unescape_html(value: str) -> str:
    """Decode HTML character references (named and numeric) in *value*.

    Uses the stdlib ``html.unescape``, which handles the full HTML5
    entity set in a single pass. The previous chained ``replace`` calls
    covered only four entities (missing e.g. ``&lt;``/``&gt;`` and
    numeric references) and, by replacing ``&amp;`` first, could
    double-unescape input such as ``&amp;quot;``.
    """
    return html.unescape(value)