Spaces:
Running
Running
| """ | |
| StepFun API Client for audio editing and voice cloning | |
| """ | |
| import os | |
| import time | |
| import json | |
| import logging | |
| import tempfile | |
| import requests | |
| import base64 | |
logger = logging.getLogger(__name__)

# --- API configuration -------------------------------------------------------
# Root URL shared by every StepFun endpoint used in this module.
API_BASE_URL: str = "https://api.stepfun.com/v1"
# File-status polling: pause between checks and overall give-up time (seconds).
POLL_INTERVAL: int = 2
POLL_TIMEOUT: int = 60
def get_api_token() -> str:
    """Return the StepFun API token from ``STEPFUN_API_TOKEN``.

    Returns an empty string (after logging a warning) when the variable is
    unset or empty. The environment is read on every call.
    """
    token = os.environ.get('STEPFUN_API_TOKEN', '')
    if token:
        return token
    logger.warning("⚠️ STEPFUN_API_TOKEN not set in environment variables")
    return token
| class StepFunAPIClient: | |
| """StepFun API Client for audio editing""" | |
| def token(self) -> str: | |
| """Get token dynamically (allows runtime env var changes)""" | |
| return get_api_token() | |
| def base_headers(self) -> dict: | |
| """Get headers with current token""" | |
| return {"Authorization": f"Bearer {self.token}"} | |
| def upload_file(self, file_path: str) -> str | None: | |
| """Upload audio file and return file_id""" | |
| url = f"{API_BASE_URL}/files" | |
| file_name = os.path.basename(file_path) | |
| # Determine audio type | |
| if file_path.endswith('.wav'): | |
| audio_type = 'audio/wav' | |
| elif file_path.endswith('.mp3'): | |
| audio_type = 'audio/mpeg' | |
| else: | |
| audio_type = 'audio/wav' | |
| files = [ | |
| ('file', (file_name, open(file_path, 'rb'), audio_type)) | |
| ] | |
| payload = {'purpose': 'storage'} | |
| headers = self.base_headers.copy() | |
| try: | |
| response = requests.post(url, headers=headers, data=payload, files=files, timeout=60) | |
| response.raise_for_status() | |
| result = response.json() | |
| file_id = result.get("id") | |
| logger.info(f"✅ Uploaded file {file_name}, file_id: {file_id}") | |
| return file_id | |
| except Exception as e: | |
| logger.error(f"❌ Failed to upload file {file_name}: {e}") | |
| return None | |
| def query_file_status(self, file_id: str) -> bool: | |
| """Query file status, return True if success""" | |
| url = f"{API_BASE_URL}/files/{file_id}" | |
| headers = { | |
| **self.base_headers, | |
| "Content-Type": "application/json" | |
| } | |
| try: | |
| response = requests.get(url, headers=headers, timeout=30) | |
| response.raise_for_status() | |
| result = response.json() | |
| status = result.get("status", "").lower() | |
| logger.debug(f"File {file_id} status: {status}") | |
| return status == "success" | |
| except Exception as e: | |
| logger.error(f"❌ Failed to query file status: {e}") | |
| return False | |
| def wait_for_file_ready(self, file_id: str) -> bool: | |
| """Wait for file to be ready (poll status)""" | |
| start_time = time.time() | |
| while time.time() - start_time < POLL_TIMEOUT: | |
| if self.query_file_status(file_id): | |
| return True | |
| time.sleep(POLL_INTERVAL) | |
| logger.error(f"⏰ File {file_id} status query timeout ({POLL_TIMEOUT}s)") | |
| return False | |
| def audio_edit(self, file_id: str, sample_text: str, target_text: str, | |
| edit_type: str = "clone", edit_info: str = "") -> bytes | None: | |
| """Call audio edit API and return audio bytes""" | |
| url = f"{API_BASE_URL}/audio/edit" | |
| headers = { | |
| **self.base_headers, | |
| "Content-Type": "application/json" | |
| } | |
| payload = { | |
| "model": "step-tts-edit", | |
| "file_id": file_id, | |
| "sample_text": target_text, # target text from UI | |
| "text": sample_text, # prompt text from UI | |
| "edit_type": edit_type, | |
| "edit_info": edit_info or "", | |
| "response_format": "wav" | |
| } | |
| logger.info(f"🎯 Calling audio edit API: edit_type={edit_type}, edit_info={edit_info}") | |
| try: | |
| response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=120) | |
| # Check if response is audio or error | |
| content_type = response.headers.get('Content-Type', '') | |
| if 'audio' in content_type or response.status_code == 200: | |
| if len(response.content) > 1000: # Likely audio data | |
| logger.info(f"✅ Audio edit successful, received {len(response.content)} bytes") | |
| return response.content | |
| else: | |
| # Might be an error message | |
| try: | |
| error_data = response.json() | |
| logger.error(f"❌ API error: {error_data}") | |
| except: | |
| logger.error(f"❌ API error: {response.text}") | |
| return None | |
| else: | |
| logger.error(f"❌ API error: {response.status_code} - {response.text}") | |
| return None | |
| except Exception as e: | |
| logger.error(f"❌ Audio edit request failed: {e}") | |
| return None | |
| def transcribe_audio_sse(self, audio_path: str, progress_callback=None, streaming=False): | |
| """ | |
| 使用ASR SSE接口转录音频文件 | |
| Args: | |
| audio_path: 音频文件路径 | |
| progress_callback: 可选的回调函数,用于处理增量文本更新 callback(delta_text) | |
| streaming: 是否返回生成器进行流式更新 | |
| Returns: | |
| 如果streaming=False: 完整的转录文本 | |
| 如果streaming=True: 生成器,产生增量更新和最终文本 | |
| Raises: | |
| Exception: 如果转录失败 | |
| """ | |
| if streaming: | |
| return self._transcribe_audio_sse_streaming(audio_path, progress_callback) | |
| else: | |
| return self._transcribe_audio_sse_sync(audio_path, progress_callback) | |
| def _transcribe_audio_sse_sync(self, audio_path: str, progress_callback=None) -> str: | |
| """ | |
| 同步ASR转录,返回最终文本 | |
| """ | |
| url = f"{API_BASE_URL}/audio/asr/sse" | |
| # 读取音频文件并转换为base64 | |
| try: | |
| with open(audio_path, 'rb') as audio_file: | |
| audio_data = base64.b64encode(audio_file.read()).decode('utf-8') | |
| except Exception as e: | |
| logger.error(f"❌ Failed to read audio file: {e}") | |
| raise Exception(f"Failed to read audio file: {e}") | |
| # 构建请求payload | |
| payload = { | |
| "audio": { | |
| "data": audio_data, | |
| "input": { | |
| "transcription": { | |
| "language": "zh", | |
| "prompt": "请记录下你所听到的语音内容。", | |
| "model": "step-asr", | |
| "full_rerun_on_commit": True, | |
| "enable_itn": True | |
| }, | |
| "format": { | |
| "type": "pcm", | |
| "codec": "pcm_s16le", | |
| "rate": 16000, | |
| "bits": 16, | |
| "channel": 1 | |
| } | |
| } | |
| } | |
| } | |
| headers = { | |
| **self.base_headers, | |
| "Content-Type": "application/json", | |
| "Accept": "text/event-stream" | |
| } | |
| logger.info("🎙️ Starting ASR transcription...") | |
| try: | |
| response = requests.post(url, headers=headers, data=json.dumps(payload), stream=True, timeout=120) | |
| response.raise_for_status() | |
| final_text = "" | |
| accumulated_text = "" | |
| for line in response.iter_lines(decode_unicode=True): | |
| if line: | |
| line = line.strip() | |
| if line.startswith("data: "): | |
| try: | |
| # 解析SSE数据 | |
| data_str = line[6:] # 去掉 "data: " 前缀 | |
| if data_str == "[DONE]": | |
| break | |
| event_data = json.loads(data_str) | |
| event_type = event_data.get("type") | |
| if event_type == "transcript.text.delta": | |
| delta = event_data.get("delta", "") | |
| accumulated_text += delta | |
| logger.debug(f"📝 ASR delta: {delta}") | |
| # 处理增量更新回调 | |
| if progress_callback: | |
| progress_callback(accumulated_text) | |
| elif event_type == "transcript.text.done": | |
| final_text = event_data.get("text", accumulated_text) | |
| logger.info(f"✅ ASR transcription complete: {final_text}") | |
| break | |
| elif event_type == "error": | |
| error_msg = event_data.get("message", "Unknown ASR error") | |
| logger.error(f"❌ ASR error: {error_msg}") | |
| raise Exception(f"ASR API error: {error_msg}") | |
| except json.JSONDecodeError as e: | |
| logger.warning(f"⚠️ Failed to parse SSE line: {line}, error: {e}") | |
| continue | |
| except Exception as e: | |
| logger.error(f"❌ Error processing SSE event: {e}") | |
| raise | |
| # 如果没有获得final_text,使用accumulated_text | |
| if not final_text: | |
| final_text = accumulated_text | |
| if not final_text: | |
| raise Exception("No transcription result received") | |
| logger.info(f"🎯 Final transcription: {final_text}") | |
| return final_text | |
| except Exception as e: | |
| logger.error(f"❌ ASR transcription failed: {e}") | |
| raise Exception(f"ASR transcription failed: {e}") | |
| def _transcribe_audio_sse_streaming(self, audio_path: str, progress_callback=None): | |
| """ | |
| 流式ASR转录,返回生成器 | |
| """ | |
| url = f"{API_BASE_URL}/audio/asr" | |
| # 读取音频文件并转换为base64 | |
| try: | |
| with open(audio_path, 'rb') as audio_file: | |
| audio_data = base64.b64encode(audio_file.read()).decode('utf-8') | |
| except Exception as e: | |
| logger.error(f"❌ Failed to read audio file: {e}") | |
| yield f"[读取文件失败: {e}]" | |
| return | |
| # 构建请求payload | |
| payload = { | |
| "audio": { | |
| "data": audio_data, | |
| "input": { | |
| "transcription": { | |
| "language": "zh", | |
| "prompt": "请记录下你所听到的语音内容。", | |
| "model": "step-asr", | |
| "full_rerun_on_commit": True, | |
| "enable_itn": True | |
| }, | |
| "format": { | |
| "type": "pcm", | |
| "codec": "pcm_s16le", | |
| "rate": 16000, | |
| "bits": 16, | |
| "channel": 1 | |
| } | |
| } | |
| } | |
| } | |
| headers = { | |
| **self.base_headers, | |
| "Content-Type": "application/json", | |
| "Accept": "text/event-stream" | |
| } | |
| logger.info("🎙️ Starting ASR transcription...") | |
| try: | |
| response = requests.post(url, headers=headers, data=json.dumps(payload), stream=True, timeout=120) | |
| response.raise_for_status() | |
| final_text = "" | |
| accumulated_text = "" | |
| for line in response.iter_lines(decode_unicode=True): | |
| if line: | |
| line = line.strip() | |
| if line.startswith("data: "): | |
| try: | |
| # 解析SSE数据 | |
| data_str = line[6:] # 去掉 "data: " 前缀 | |
| if data_str == "[DONE]": | |
| break | |
| event_data = json.loads(data_str) | |
| event_type = event_data.get("type") | |
| if event_type == "transcript.text.delta": | |
| delta = event_data.get("delta", "") | |
| accumulated_text += delta | |
| logger.debug(f"📝 ASR delta: {delta}") | |
| # 流式更新 | |
| yield accumulated_text | |
| if progress_callback: | |
| progress_callback(accumulated_text) | |
| elif event_type == "transcript.text.done": | |
| final_text = event_data.get("text", accumulated_text) | |
| logger.info(f"✅ ASR transcription complete: {final_text}") | |
| yield final_text | |
| return | |
| elif event_type == "error": | |
| error_msg = event_data.get("message", "Unknown ASR error") | |
| logger.error(f"❌ ASR error: {error_msg}") | |
| yield f"[ASR错误: {error_msg}]" | |
| return | |
| except json.JSONDecodeError as e: | |
| logger.warning(f"⚠️ Failed to parse SSE line: {line}, error: {e}") | |
| continue | |
| except Exception as e: | |
| logger.error(f"❌ Error processing SSE event: {e}") | |
| yield f"[处理错误: {str(e)}]" | |
| return | |
| # 如果没有获得final_text,使用accumulated_text | |
| if not final_text: | |
| final_text = accumulated_text | |
| if not final_text: | |
| yield f"[转录无结果]" | |
| return | |
| logger.info(f"🎯 Final transcription: {final_text}") | |
| except Exception as e: | |
| logger.error(f"❌ ASR transcription failed: {e}") | |
| yield f"[转录失败: {str(e)}]" | |
# One shared client for the whole module; get_client() always returns this object.
_client = StepFunAPIClient()


def get_client() -> StepFunAPIClient:
    """Hand back the module-wide ``StepFunAPIClient`` singleton."""
    shared = _client
    return shared
def process_audio(audio_input: str, prompt_text: str, target_text: str,
                  edit_type: str, edit_info: str | None = None) -> str:
    """
    Process audio using the StepFun API: upload, wait until ready, edit, save.

    Args:
        audio_input: Path to input audio file
        prompt_text: Text content of the input audio
        target_text: Target text for generation
        edit_type: Type of edit (clone, gender, age, etc.)
        edit_info: Additional edit info (e.g., "male", "happy"); None means ""

    Returns:
        Path to a newly created temporary ``.wav`` file (not deleted
        automatically; the caller owns it)

    Raises:
        ValueError: If API token not configured
        RuntimeError: If upload, file processing, or the edit call fails
    """
    # Check the environment directly: get_api_token() always returns a str,
    # so the check does not depend on how the client exposes its token.
    if not get_api_token():
        raise ValueError("API token not configured. Please set STEPFUN_API_TOKEN environment variable.")
    client = get_client()

    # 1. Upload audio file
    logger.info("📤 Uploading audio file...")
    file_id = client.upload_file(audio_input)
    if not file_id:
        raise RuntimeError("Failed to upload audio file")

    # 2. Wait for file to be ready
    logger.info("⏳ Waiting for file to be ready...")
    if not client.wait_for_file_ready(file_id):
        raise RuntimeError("File processing timeout")

    # 3. Call audio edit API
    logger.info("🎤 Calling audio edit API...")
    audio_bytes = client.audio_edit(
        file_id=file_id,
        sample_text=prompt_text,
        target_text=target_text,
        edit_type=edit_type,
        edit_info=edit_info or "",
    )
    if not audio_bytes:
        raise RuntimeError("Audio edit API returned no data")

    # 4. Save to a temp file that outlives this function (delete=False) and return its path
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        output_path = f.name
    logger.info(f"✅ Audio saved to: {output_path}")
    return output_path
def transcribe_audio(audio_path: str, progress_callback=None, streaming=False):
    """Transcribe ``audio_path`` with StepFun ASR, either blocking or as a generator.

    Args:
        audio_path: Path of the audio file to transcribe.
        progress_callback: Optional callable invoked with the text accumulated
            so far after each incremental update.
        streaming: Truthy to get a generator of incremental updates instead of
            a single string.

    Returns:
        Non-streaming: the full transcript (see ``transcribe_audio_sync``).
        Streaming: the generator from ``transcribe_audio_streaming``, which
        reports failures as bracketed marker strings rather than raising.

    Raises:
        ValueError: (non-streaming) if the API token is not configured.
        RuntimeError: (non-streaming) if transcription fails.
    """
    runner = transcribe_audio_streaming if streaming else transcribe_audio_sync
    return runner(audio_path, progress_callback)
def transcribe_audio_sync(audio_path: str, progress_callback=None) -> str:
    """
    Transcribe an audio file synchronously and return the final text.

    Args:
        audio_path: Path of the audio file to transcribe
        progress_callback: Optional callable invoked with the text accumulated
            so far after each incremental update

    Returns:
        The complete transcript

    Raises:
        ValueError: If API token not configured
        RuntimeError: If transcription fails (the original error is chained)
    """
    # get_api_token() always returns a str, so this check is reliable
    # regardless of how the client exposes its token.
    if not get_api_token():
        raise ValueError("API token not configured. Please set STEPFUN_API_TOKEN environment variable.")
    client = get_client()
    try:
        return client.transcribe_audio_sse(audio_path, progress_callback, streaming=False)
    except Exception as e:
        raise RuntimeError(f"Transcription failed: {e}") from e
def transcribe_audio_streaming(audio_path: str, progress_callback=None):
    """
    Transcribe an audio file as a generator of incremental updates.

    Args:
        audio_path: Path of the audio file to transcribe
        progress_callback: Optional callable invoked with the text accumulated
            so far after each incremental update

    Yields:
        The accumulated text after each update and then the final text. Errors
        are never raised; they are yielded as bracketed marker strings, e.g.
        ``[错误: API token not configured]`` or ``[转录失败: ...]``.
    """
    # get_api_token() always returns a str, so this check is reliable
    # regardless of how the client exposes its token.
    if not get_api_token():
        yield "[错误: API token not configured]"
        return
    client = get_client()
    try:
        # ``yield from`` forwards close() to the inner generator so the
        # underlying HTTP stream is released if the consumer stops early.
        yield from client.transcribe_audio_sse(audio_path, progress_callback, streaming=True)
    except Exception as e:
        yield f"[转录失败: {str(e)}]"