"""
StepFun API Client for audio editing and voice cloning.

Thin wrapper over the StepFun HTTP API rooted at ``API_BASE_URL``:

* ``POST /files`` and ``GET /files/{id}`` - upload an audio sample and poll
  its processing status.
* ``POST /audio/edit`` - voice clone / edit from an uploaded sample.
* ``POST /audio/asr/sse`` - speech recognition streamed as Server-Sent Events.

The bearer token is read from the ``STEPFUN_API_TOKEN`` environment variable
every time request headers are built, so it can be changed at runtime.
"""
import os
import time
import json
import logging
import tempfile
import requests
import base64

logger = logging.getLogger(__name__)

# API Configuration
API_BASE_URL = "https://api.stepfun.com/v1"
POLL_INTERVAL = 2  # seconds
POLL_TIMEOUT = 60  # seconds


def get_api_token() -> str:
    """Return the API token from ``STEPFUN_API_TOKEN`` ('' when unset; a warning is logged)."""
    token = os.getenv('STEPFUN_API_TOKEN', '')
    if not token:
        logger.warning("⚠️ STEPFUN_API_TOKEN not set in environment variables")
    return token


class StepFunAPIClient:
    """StepFun API client for audio upload, audio editing and ASR transcription.

    The client holds no state of its own: the token is re-read from the
    environment on each access.
    """

    @property
    def token(self) -> str:
        """Get token dynamically (allows runtime env var changes)."""
        return get_api_token()

    @property
    def base_headers(self) -> dict:
        """Return the ``Authorization`` header built from the current token."""
        return {"Authorization": f"Bearer {self.token}"}

    # ------------------------------------------------------------------ files

    def upload_file(self, file_path: str) -> str | None:
        """Upload an audio file and return its ``file_id``.

        Args:
            file_path: Local path of the audio file. Files ending in ``.mp3``
                (any case) are sent as ``audio/mpeg``; everything else is sent
                as ``audio/wav``.

        Returns:
            The ``id`` field of the upload response, or ``None`` if the request
            failed or the response could not be parsed (the error is logged).

        Raises:
            OSError: If ``file_path`` cannot be opened.
        """
        url = f"{API_BASE_URL}/files"
        file_name = os.path.basename(file_path)
        # Case-insensitive so that e.g. "VOICE.MP3" is labelled correctly.
        audio_type = 'audio/mpeg' if file_path.lower().endswith('.mp3') else 'audio/wav'
        payload = {'purpose': 'storage'}
        headers = self.base_headers.copy()

        # The context manager guarantees the handle is closed once the upload
        # finishes (previously it was leaked).
        with open(file_path, 'rb') as audio_file:
            files = [('file', (file_name, audio_file, audio_type))]
            try:
                response = requests.post(url, headers=headers, data=payload, files=files, timeout=60)
                response.raise_for_status()
                file_id = response.json().get("id")
            except Exception as e:
                logger.error(f"❌ Failed to upload file {file_name}: {e}")
                return None

        logger.info(f"✅ Uploaded file {file_name}, file_id: {file_id}")
        return file_id

    def query_file_status(self, file_id: str) -> bool:
        """Return True if the file's ``status`` field is ``"success"`` (case-insensitive).

        Any other status, and any request/parse error (logged), yields False.
        NOTE(review): a terminal failure status cannot be told apart from
        "still processing" here, so ``wait_for_file_ready`` keeps polling
        until its timeout in that case.
        """
        url = f"{API_BASE_URL}/files/{file_id}"
        headers = {
            **self.base_headers,
            "Content-Type": "application/json",
        }
        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            result = response.json()
            # `or ""` also covers an explicit JSON null status.
            status = str(result.get("status") or "").lower()
            logger.debug(f"File {file_id} status: {status}")
            return status == "success"
        except Exception as e:
            logger.error(f"❌ Failed to query file status: {e}")
            return False

    def wait_for_file_ready(self, file_id: str, timeout: float | None = None,
                            interval: float | None = None) -> bool:
        """Poll ``query_file_status`` until it succeeds or the timeout expires.

        Args:
            file_id: ID returned by ``upload_file``.
            timeout: Maximum seconds to wait; defaults to ``POLL_TIMEOUT``.
            interval: Seconds to sleep between polls; defaults to ``POLL_INTERVAL``.

        Returns:
            True once the file is ready, False on timeout (logged).
        """
        # Resolve defaults at call time so runtime changes to the module
        # constants are still honoured.
        timeout = POLL_TIMEOUT if timeout is None else timeout
        interval = POLL_INTERVAL if interval is None else interval

        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.query_file_status(file_id):
                return True
            time.sleep(interval)
        logger.error(f"⏰ File {file_id} status query timeout ({timeout}s)")
        return False

    # ------------------------------------------------------------- audio edit

    def audio_edit(self, file_id: str, sample_text: str, target_text: str,
                   edit_type: str = "clone", edit_info: str = "") -> bytes | None:
        """Call the audio edit API and return the generated WAV bytes.

        Args:
            file_id: ID of an uploaded, ready sample (see ``upload_file``).
            sample_text: Transcript of the uploaded sample (the UI's prompt text).
            target_text: Text to synthesise (the UI's target text).
            edit_type: Edit mode sent as ``edit_type`` (e.g. ``"clone"``).
            edit_info: Extra edit parameter sent as ``edit_info``.

        Returns:
            The response body when it looks like audio, else ``None`` (logged).
        """
        url = f"{API_BASE_URL}/audio/edit"
        headers = {
            **self.base_headers,
            "Content-Type": "application/json",
        }
        payload = {
            "model": "step-tts-edit",
            "file_id": file_id,
            # NOTE(review): these two fields are cross-wired on purpose per the
            # original comments (API "sample_text" <- target text, API "text"
            # <- prompt text). Confirm against the API docs before changing.
            "sample_text": target_text,  # target text from UI
            "text": sample_text,  # prompt text from UI
            "edit_type": edit_type,
            "edit_info": edit_info or "",
            "response_format": "wav",
        }

        logger.info(f"🎯 Calling audio edit API: edit_type={edit_type}, edit_info={edit_info}")

        try:
            response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=120)
        except Exception as e:
            logger.error(f"❌ Audio edit request failed: {e}")
            return None

        content_type = response.headers.get('Content-Type', '')
        if 'audio' not in content_type and response.status_code != 200:
            logger.error(f"❌ API error: {response.status_code} - {response.text}")
            return None

        # Heuristic: a real WAV body is larger than 1000 bytes, while a JSON
        # body is always an error report regardless of size.
        if 'json' not in content_type and len(response.content) > 1000:
            logger.info(f"✅ Audio edit successful, received {len(response.content)} bytes")
            return response.content

        try:
            error_data = response.json()
            logger.error(f"❌ API error: {error_data}")
        except ValueError:  # body is not JSON
            logger.error(f"❌ API error: {response.text}")
        return None

    # -------------------------------------------------------------------- ASR

    def transcribe_audio_sse(self, audio_path: str, progress_callback=None, streaming=False):
        """Transcribe an audio file through the ASR SSE endpoint.

        Args:
            audio_path: Path of the audio file.
            progress_callback: Optional ``callback(text)`` invoked after each
                delta event with the transcript accumulated so far.
            streaming: Whether to return a generator of incremental updates.

        Returns:
            If ``streaming`` is False: the complete transcript.
            If ``streaming`` is True: a generator yielding the accumulated
            transcript after each delta, then the final text; errors are
            yielded as bracketed messages instead of raised.

        Raises:
            Exception: (non-streaming only) if the transcription fails.
        """
        if streaming:
            return self._transcribe_audio_sse_streaming(audio_path, progress_callback)
        return self._transcribe_audio_sse_sync(audio_path, progress_callback)

    @staticmethod
    def _read_audio_base64(audio_path: str) -> str:
        """Return the raw bytes of ``audio_path`` base64-encoded as text."""
        with open(audio_path, 'rb') as audio_file:
            return base64.b64encode(audio_file.read()).decode('utf-8')

    @staticmethod
    def _build_asr_payload(audio_data: str) -> dict:
        """Build the ASR request body around base64-encoded audio.

        NOTE(review): the declared format is raw 16 kHz mono s16le PCM, but the
        file bytes are sent unmodified (a WAV header, if any, included) —
        confirm the server tolerates this.
        """
        return {
            "audio": {
                "data": audio_data,
                "input": {
                    "transcription": {
                        "language": "zh",
                        "prompt": "请记录下你所听到的语音内容。",
                        "model": "step-asr",
                        "full_rerun_on_commit": True,
                        "enable_itn": True,
                    },
                    "format": {
                        "type": "pcm",
                        "codec": "pcm_s16le",
                        "rate": 16000,
                        "bits": 16,
                        "channel": 1,
                    },
                },
            }
        }

    def _open_asr_stream(self, audio_data: str) -> requests.Response:
        """POST an ASR request and return the unread streaming response.

        The caller owns the response and must close it (use it in ``with``).
        """
        url = f"{API_BASE_URL}/audio/asr/sse"
        headers = {
            **self.base_headers,
            "Content-Type": "application/json",
            "Accept": "text/event-stream",
        }
        payload = self._build_asr_payload(audio_data)
        return requests.post(url, headers=headers, data=json.dumps(payload),
                             stream=True, timeout=120)

    @staticmethod
    def _iter_sse_events(response):
        """Yield the JSON object carried by each ``data:`` line of an SSE response.

        Iteration stops at a ``data: [DONE]`` sentinel or at end of stream.
        Lines that are not JSON objects are logged and skipped.
        """
        # SSE streams are always UTF-8, but requests falls back to ISO-8859-1
        # for text/* responses without a charset, which garbles Chinese text.
        response.encoding = "utf-8"
        for raw_line in response.iter_lines(decode_unicode=True):
            line = raw_line.strip() if raw_line else ""
            if not line.startswith("data: "):
                continue
            data_str = line.removeprefix("data: ")
            if data_str == "[DONE]":
                return
            try:
                event = json.loads(data_str)
            except json.JSONDecodeError as e:
                logger.warning(f"⚠️ Failed to parse SSE line: {line}, error: {e}")
                continue
            if not isinstance(event, dict):
                logger.warning(f"⚠️ Failed to parse SSE line: {line}, error: not a JSON object")
                continue
            yield event

    def _transcribe_audio_sse_sync(self, audio_path: str, progress_callback=None) -> str:
        """Transcribe ``audio_path`` synchronously and return the final text.

        Raises:
            Exception: If the file cannot be read, the request fails, the API
                reports an error event, or no text is received.
        """
        try:
            audio_data = self._read_audio_base64(audio_path)
        except Exception as e:
            logger.error(f"❌ Failed to read audio file: {e}")
            raise Exception(f"Failed to read audio file: {e}") from e

        logger.info("🎙️ Starting ASR transcription...")

        try:
            final_text = ""
            accumulated_text = ""
            # Leaving the `with` (including via break/raise) releases the
            # streaming connection.
            with self._open_asr_stream(audio_data) as response:
                response.raise_for_status()
                for event in self._iter_sse_events(response):
                    event_type = event.get("type")
                    if event_type == "transcript.text.delta":
                        delta = event.get("delta", "")
                        accumulated_text += delta
                        logger.debug(f"📝 ASR delta: {delta}")
                        if progress_callback:
                            progress_callback(accumulated_text)
                    elif event_type == "transcript.text.done":
                        final_text = event.get("text", accumulated_text)
                        logger.info(f"✅ ASR transcription complete: {final_text}")
                        break
                    elif event_type == "error":
                        error_msg = event.get("message", "Unknown ASR error")
                        logger.error(f"❌ ASR error: {error_msg}")
                        raise Exception(f"ASR API error: {error_msg}")

            # Fall back to the concatenated deltas if no (non-empty) done event.
            if not final_text:
                final_text = accumulated_text
            if not final_text:
                raise Exception("No transcription result received")

            logger.info(f"🎯 Final transcription: {final_text}")
            return final_text
        except Exception as e:
            logger.error(f"❌ ASR transcription failed: {e}")
            raise Exception(f"ASR transcription failed: {e}") from e

    def _transcribe_audio_sse_streaming(self, audio_path: str, progress_callback=None):
        """Generator form of ASR transcription.

        Yields the accumulated transcript after each delta event and the final
        text on the done event. Failures are yielded as bracketed messages and
        end the generator; nothing is raised.
        """
        try:
            audio_data = self._read_audio_base64(audio_path)
        except Exception as e:
            logger.error(f"❌ Failed to read audio file: {e}")
            yield f"[读取文件失败: {e}]"
            return

        logger.info("🎙️ Starting ASR transcription...")

        try:
            final_text = ""
            accumulated_text = ""
            # The `with` also closes the connection if the consumer stops early.
            with self._open_asr_stream(audio_data) as response:
                response.raise_for_status()
                for event in self._iter_sse_events(response):
                    try:
                        event_type = event.get("type")
                        if event_type == "transcript.text.delta":
                            delta = event.get("delta", "")
                            accumulated_text += delta
                            logger.debug(f"📝 ASR delta: {delta}")
                            yield accumulated_text
                            if progress_callback:
                                progress_callback(accumulated_text)
                        elif event_type == "transcript.text.done":
                            final_text = event.get("text", accumulated_text)
                            logger.info(f"✅ ASR transcription complete: {final_text}")
                            yield final_text
                            return
                        elif event_type == "error":
                            error_msg = event.get("message", "Unknown ASR error")
                            logger.error(f"❌ ASR error: {error_msg}")
                            yield f"[ASR错误: {error_msg}]"
                            return
                    except Exception as e:
                        # e.g. an exception raised by progress_callback
                        logger.error(f"❌ Error processing SSE event: {e}")
                        yield f"[处理错误: {str(e)}]"
                        return

            # Stream ended without a done event: the last yield was already
            # the accumulated text, so only report the empty case.
            if not final_text:
                final_text = accumulated_text
            if not final_text:
                yield "[转录无结果]"
                return

            logger.info(f"🎯 Final transcription: {final_text}")
        except Exception as e:
            logger.error(f"❌ ASR transcription failed: {e}")
            yield f"[转录失败: {str(e)}]"


# Global API client instance (singleton)
_client = StepFunAPIClient()


def get_client() -> StepFunAPIClient:
    """Return the module-wide shared API client."""
    return _client


def process_audio(audio_input: str, prompt_text: str, target_text: str,
                  edit_type: str, edit_info: str | None = None) -> str:
    """Process audio using the StepFun API: upload, wait, edit, save.

    Args:
        audio_input: Path to the input audio file.
        prompt_text: Text content of the input audio.
        target_text: Target text for generation.
        edit_type: Type of edit (clone, gender, age, etc.).
        edit_info: Additional edit info (e.g. "male", "happy").

    Returns:
        Path to a new ``.wav`` temp file holding the result. The file is not
        deleted automatically; the caller owns it.

    Raises:
        ValueError: If the API token is not configured.
        RuntimeError: If upload, file processing, or the edit call fails.
    """
    client = get_client()
    if not client.token:
        raise ValueError("API token not configured. Please set STEPFUN_API_TOKEN environment variable.")

    # 1. Upload audio file
    logger.info("📤 Uploading audio file...")
    file_id = client.upload_file(audio_input)
    if not file_id:
        raise RuntimeError("Failed to upload audio file")

    # 2. Wait for file to be ready
    logger.info("⏳ Waiting for file to be ready...")
    if not client.wait_for_file_ready(file_id):
        raise RuntimeError("File processing timeout")

    # 3. Call audio edit API
    logger.info("🎤 Calling audio edit API...")
    audio_bytes = client.audio_edit(
        file_id=file_id,
        sample_text=prompt_text,
        target_text=target_text,
        edit_type=edit_type,
        edit_info=edit_info or "",
    )
    if not audio_bytes:
        raise RuntimeError("Audio edit API returned no data")

    # 4. Save to temp file and return path (delete=False: caller owns the file)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        output_path = f.name

    logger.info(f"✅ Audio saved to: {output_path}")
    return output_path


def transcribe_audio(audio_path: str, progress_callback=None, streaming=False):
    """Transcribe an audio file through the ASR SSE endpoint.

    Args:
        audio_path: Path of the audio file.
        progress_callback: Optional ``callback(text)`` invoked with the
            transcript accumulated so far after each delta.
        streaming: Whether to return a generator of incremental updates.

    Returns:
        If ``streaming`` is False: the complete transcript.
        If ``streaming`` is True: a generator of incremental updates and the
        final text (errors are yielded as bracketed messages).

    Raises:
        ValueError: (non-streaming only) if the API token is not configured.
        RuntimeError: (non-streaming only) if transcription fails.
    """
    if streaming:
        return transcribe_audio_streaming(audio_path, progress_callback)
    return transcribe_audio_sync(audio_path, progress_callback)


def transcribe_audio_sync(audio_path: str, progress_callback=None) -> str:
    """Transcribe an audio file and return the final text.

    Args:
        audio_path: Path of the audio file.
        progress_callback: Optional ``callback(text)`` invoked with the
            transcript accumulated so far after each delta.

    Returns:
        The complete transcript.

    Raises:
        ValueError: If the API token is not configured.
        RuntimeError: If transcription fails (original error chained).
    """
    client = get_client()
    if not client.token:
        raise ValueError("API token not configured. Please set STEPFUN_API_TOKEN environment variable.")

    try:
        return client.transcribe_audio_sse(audio_path, progress_callback, streaming=False)
    except Exception as e:
        raise RuntimeError(f"Transcription failed: {e}") from e


def transcribe_audio_streaming(audio_path: str, progress_callback=None):
    """Generator that streams transcription updates for an audio file.

    Args:
        audio_path: Path of the audio file.
        progress_callback: Optional ``callback(text)`` invoked with the
            transcript accumulated so far after each delta.

    Yields:
        Incremental (accumulated) transcripts, then the final text. A missing
        token or any failure is yielded as a bracketed message; nothing is
        raised.
    """
    client = get_client()
    if not client.token:
        yield "[错误: API token not configured]"
        return

    try:
        yield from client.transcribe_audio_sse(audio_path, progress_callback, streaming=True)
    except Exception as e:
        yield f"[转录失败: {str(e)}]"