Step-Audio-EditX / stepfun_api.py
xieli
feat: fix url
eea7b6a
"""
StepFun API Client for audio editing and voice cloning
"""
import os
import time
import json
import logging
import tempfile
import requests
import base64
logger: logging.Logger = logging.getLogger(__name__)

# --- StepFun API settings ---------------------------------------------------
API_BASE_URL: str = "https://api.stepfun.com/v1"  # root URL for every endpoint
POLL_INTERVAL: int = 2  # seconds to sleep between file-status checks
POLL_TIMEOUT: int = 60  # seconds to keep polling before giving up
def get_api_token() -> str:
    """Return the StepFun API token from ``STEPFUN_API_TOKEN``.

    Returns an empty string (and logs a warning) when the variable is unset
    or empty.
    """
    env_token = os.getenv('STEPFUN_API_TOKEN', '')
    if env_token:
        return env_token
    logger.warning("⚠️ STEPFUN_API_TOKEN not set in environment variables")
    return env_token
class StepFunAPIClient:
    """HTTP client for the StepFun file, audio-edit and ASR endpoints.

    The instance stores no state of its own: the bearer token is re-read from
    the environment (via ``get_api_token``) every time headers are built, so a
    single shared instance picks up runtime changes to ``STEPFUN_API_TOKEN``.
    """

    @property
    def token(self) -> str:
        """Current API token, looked up again on every access (may be empty)."""
        return get_api_token()

    @property
    def base_headers(self) -> dict:
        """Fresh ``Authorization: Bearer <token>`` header dict for one request."""
        return {"Authorization": f"Bearer {self.token}"}

    @staticmethod
    def _guess_audio_type(file_path: str) -> str:
        """Return the MIME type used when uploading ``file_path``.

        ``.mp3`` (any letter case) maps to ``audio/mpeg``; everything else,
        including ``.wav`` and unknown extensions, is sent as ``audio/wav``.
        """
        if file_path.lower().endswith('.mp3'):
            return 'audio/mpeg'
        return 'audio/wav'

    def upload_file(self, file_path: str) -> str | None:
        """Upload a local audio file to ``/files`` and return its ``file_id``.

        Args:
            file_path: Path of the audio file to upload.

        Returns:
            The ``id`` field of the JSON response (None if that field is
            absent), or None if the request or response parsing fails.

        Raises:
            OSError: If ``file_path`` cannot be opened. This happens before any
                request is made and is not converted into a None return.
        """
        url = f"{API_BASE_URL}/files"
        file_name = os.path.basename(file_path)
        audio_type = self._guess_audio_type(file_path)
        payload = {'purpose': 'storage'}
        headers = self.base_headers.copy()
        # ``with`` guarantees the file handle is closed once the upload is done,
        # whether it succeeded or not.
        with open(file_path, 'rb') as audio_file:
            files = [('file', (file_name, audio_file, audio_type))]
            try:
                response = requests.post(url, headers=headers, data=payload, files=files, timeout=60)
                response.raise_for_status()
                result = response.json()
                file_id = result.get("id")
                logger.info(f"✅ Uploaded file {file_name}, file_id: {file_id}")
                return file_id
            except Exception as e:
                logger.error(f"❌ Failed to upload file {file_name}: {e}")
                return None

    def query_file_status(self, file_id: str) -> bool:
        """Return True if ``/files/{file_id}`` reports status ``success``.

        Any request or parse failure is logged and reported as False, so callers
        cannot tell "not ready yet" apart from "query failed".
        """
        url = f"{API_BASE_URL}/files/{file_id}"
        headers = {
            **self.base_headers,
            "Content-Type": "application/json",
        }
        try:
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            result = response.json()
            # ``or ""`` also covers an explicit ``"status": null`` in the response.
            status = (result.get("status") or "").lower()
            logger.debug(f"File {file_id} status: {status}")
            return status == "success"
        except Exception as e:
            logger.error(f"❌ Failed to query file status: {e}")
            return False

    def wait_for_file_ready(self, file_id: str) -> bool:
        """Poll the file status until it is ``success`` or the timeout expires.

        Checks every POLL_INTERVAL seconds for at most POLL_TIMEOUT seconds.

        Returns:
            True once the file is ready, False on timeout.
        """
        # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + POLL_TIMEOUT
        while time.monotonic() < deadline:
            if self.query_file_status(file_id):
                return True
            time.sleep(POLL_INTERVAL)
        logger.error(f"⏰ File {file_id} status query timeout ({POLL_TIMEOUT}s)")
        return False

    def audio_edit(self, file_id: str, sample_text: str, target_text: str,
                   edit_type: str = "clone", edit_info: str = "") -> bytes | None:
        """Call ``/audio/edit`` and return the generated audio bytes.

        Args:
            file_id: Id of a previously uploaded reference audio file.
            sample_text: Prompt text from the UI (the text content of the
                reference audio, as passed by ``process_audio``).
            target_text: Target text from the UI.
            edit_type: Edit mode, e.g. ``"clone"``.
            edit_info: Extra edit parameter; sent as ``""`` when falsy.

        Returns:
            The raw response body when it looks like audio, otherwise None.
            "Looks like audio" means: an audio Content-Type or HTTP 200, and
            more than 1000 bytes. All errors are logged.
        """
        url = f"{API_BASE_URL}/audio/edit"
        headers = {
            **self.base_headers,
            "Content-Type": "application/json",
        }
        payload = {
            "model": "step-tts-edit",
            "file_id": file_id,
            # NOTE(review): the UI's target text goes into "sample_text" and the
            # prompt text into "text". The original inline comments mark this
            # crossing as deliberate; confirm against the audio-edit API docs.
            "sample_text": target_text,
            "text": sample_text,
            "edit_type": edit_type,
            "edit_info": edit_info or "",
            "response_format": "wav",
        }
        logger.info(f"🎯 Calling audio edit API: edit_type={edit_type}, edit_info={edit_info}")
        try:
            response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=120)
            content_type = response.headers.get('Content-Type', '')
            if 'audio' not in content_type and response.status_code != 200:
                logger.error(f"❌ API error: {response.status_code} - {response.text}")
                return None
            # Heuristic: a body over 1000 bytes is assumed to be audio data;
            # anything shorter is treated as an error payload.
            if len(response.content) > 1000:
                logger.info(f"✅ Audio edit successful, received {len(response.content)} bytes")
                return response.content
            try:
                error_data = response.json()
                logger.error(f"❌ API error: {error_data}")
            except ValueError:
                # Body is not JSON; log it verbatim.
                logger.error(f"❌ API error: {response.text}")
            return None
        except Exception as e:
            logger.error(f"❌ Audio edit request failed: {e}")
            return None

    def transcribe_audio_sse(self, audio_path: str, progress_callback=None, streaming=False):
        """Transcribe an audio file through the StepFun ASR SSE API.

        Args:
            audio_path: Path of the audio file.
            progress_callback: Optional ``callback(text)``. It is called with
                the accumulated transcript so far after each delta event.
            streaming: If True, return a generator instead of blocking.

        Returns:
            streaming=False: the complete transcript (str).
            streaming=True: a generator yielding incremental and final text.

        Raises:
            Exception: Only when streaming=False, if transcription fails. The
                streaming generator yields bracketed error strings instead.
        """
        if streaming:
            return self._transcribe_audio_sse_streaming(audio_path, progress_callback)
        else:
            return self._transcribe_audio_sse_sync(audio_path, progress_callback)

    @staticmethod
    def _read_audio_base64(audio_path: str) -> str:
        """Return the file's raw bytes base64-encoded as a text string."""
        with open(audio_path, 'rb') as audio_file:
            return base64.b64encode(audio_file.read()).decode('utf-8')

    @staticmethod
    def _build_asr_payload(audio_data: str) -> dict:
        """Build the ASR request body around base64-encoded ``audio_data``.

        NOTE(review): the declared input format is raw 16 kHz mono s16le PCM,
        but the file bytes are sent unchanged. Confirm that the server accepts
        container formats (WAV/MP3), or that callers pass raw PCM.
        """
        return {
            "audio": {
                "data": audio_data,
                "input": {
                    "transcription": {
                        "language": "zh",
                        "prompt": "请记录下你所听到的语音内容。",
                        "model": "step-asr",
                        "full_rerun_on_commit": True,
                        "enable_itn": True,
                    },
                    "format": {
                        "type": "pcm",
                        "codec": "pcm_s16le",
                        "rate": 16000,
                        "bits": 16,
                        "channel": 1,
                    },
                },
            },
        }

    def _asr_headers(self) -> dict:
        """Headers for a JSON request that expects a Server-Sent Events reply."""
        return {
            **self.base_headers,
            "Content-Type": "application/json",
            "Accept": "text/event-stream",
        }

    @staticmethod
    def _sse_data(line) -> str | None:
        """Return the value of an SSE ``data:`` line, or None for any other line.

        Blank lines, comments and other SSE fields (``event:``, ``id:``) give None.
        """
        if not line:
            return None
        line = line.strip()
        if not line.startswith("data:"):
            return None
        # The SSE spec allows an optional space after the colon.
        return line[len("data:"):].strip()

    def _transcribe_audio_sse_sync(self, audio_path: str, progress_callback=None) -> str:
        """Transcribe ``audio_path`` and return the final text (blocking).

        ``progress_callback``, if given, is called with the accumulated text
        after each delta event.

        Raises:
            Exception: If the file cannot be read, the request fails, the
                server sends an ``error`` event, or no text is received.
        """
        url = f"{API_BASE_URL}/audio/asr/sse"
        try:
            audio_data = self._read_audio_base64(audio_path)
        except Exception as e:
            logger.error(f"❌ Failed to read audio file: {e}")
            raise Exception(f"Failed to read audio file: {e}") from e
        payload = self._build_asr_payload(audio_data)
        headers = self._asr_headers()
        logger.info("🎙️ Starting ASR transcription...")
        try:
            # ``with`` closes the streamed connection even though we usually stop
            # reading at the "done" event instead of draining the body.
            with requests.post(url, headers=headers, data=json.dumps(payload),
                               stream=True, timeout=120) as response:
                response.raise_for_status()
                # SSE is UTF-8 by spec. Without an explicit charset, requests
                # decodes text/* as ISO-8859-1, which garbles non-ASCII text.
                response.encoding = 'utf-8'
                final_text = ""
                accumulated_text = ""
                for line in response.iter_lines(decode_unicode=True):
                    data_str = self._sse_data(line)
                    if data_str is None:
                        continue
                    if data_str == "[DONE]":
                        break
                    try:
                        event_data = json.loads(data_str)
                        event_type = event_data.get("type")
                        if event_type == "transcript.text.delta":
                            delta = event_data.get("delta", "")
                            accumulated_text += delta
                            logger.debug(f"📝 ASR delta: {delta}")
                            if progress_callback:
                                progress_callback(accumulated_text)
                        elif event_type == "transcript.text.done":
                            final_text = event_data.get("text", accumulated_text)
                            logger.info(f"✅ ASR transcription complete: {final_text}")
                            break
                        elif event_type == "error":
                            error_msg = event_data.get("message", "Unknown ASR error")
                            logger.error(f"❌ ASR error: {error_msg}")
                            raise Exception(f"ASR API error: {error_msg}")
                    except json.JSONDecodeError as e:
                        logger.warning(f"⚠️ Failed to parse SSE line: {line}, error: {e}")
                        continue
                    except Exception as e:
                        logger.error(f"❌ Error processing SSE event: {e}")
                        raise
            # Fall back to the concatenated deltas if no "done" event arrived.
            final_text = final_text or accumulated_text
            if not final_text:
                raise Exception("No transcription result received")
            logger.info(f"🎯 Final transcription: {final_text}")
            return final_text
        except Exception as e:
            logger.error(f"❌ ASR transcription failed: {e}")
            raise Exception(f"ASR transcription failed: {e}") from e

    def _transcribe_audio_sse_streaming(self, audio_path: str, progress_callback=None):
        """Generator form of ASR transcription.

        Yields the accumulated text after each delta event, and the final text
        on the "done" event. It never raises for transcription failures: each
        failure is yielded as a single bracketed message (e.g. ``[转录失败: ...]``)
        and the generator then ends.
        """
        # NOTE(review): this endpoint differs from the ``/audio/asr/sse`` URL used
        # by ``_transcribe_audio_sse_sync``; confirm which one is correct.
        url = f"{API_BASE_URL}/audio/asr"
        try:
            audio_data = self._read_audio_base64(audio_path)
        except Exception as e:
            logger.error(f"❌ Failed to read audio file: {e}")
            yield f"[读取文件失败: {e}]"
            return
        payload = self._build_asr_payload(audio_data)
        headers = self._asr_headers()
        logger.info("🎙️ Starting ASR transcription...")
        try:
            # ``with`` releases the connection when we return early or when the
            # consumer closes this generator.
            with requests.post(url, headers=headers, data=json.dumps(payload),
                               stream=True, timeout=120) as response:
                response.raise_for_status()
                # SSE is UTF-8 by spec; see _transcribe_audio_sse_sync.
                response.encoding = 'utf-8'
                final_text = ""
                accumulated_text = ""
                for line in response.iter_lines(decode_unicode=True):
                    data_str = self._sse_data(line)
                    if data_str is None:
                        continue
                    if data_str == "[DONE]":
                        break
                    try:
                        event_data = json.loads(data_str)
                        event_type = event_data.get("type")
                        if event_type == "transcript.text.delta":
                            delta = event_data.get("delta", "")
                            accumulated_text += delta
                            logger.debug(f"📝 ASR delta: {delta}")
                            yield accumulated_text
                            if progress_callback:
                                progress_callback(accumulated_text)
                        elif event_type == "transcript.text.done":
                            final_text = event_data.get("text", accumulated_text)
                            logger.info(f"✅ ASR transcription complete: {final_text}")
                            yield final_text
                            return
                        elif event_type == "error":
                            error_msg = event_data.get("message", "Unknown ASR error")
                            logger.error(f"❌ ASR error: {error_msg}")
                            yield f"[ASR错误: {error_msg}]"
                            return
                    except json.JSONDecodeError as e:
                        logger.warning(f"⚠️ Failed to parse SSE line: {line}, error: {e}")
                        continue
                    except Exception as e:
                        logger.error(f"❌ Error processing SSE event: {e}")
                        yield f"[处理错误: {str(e)}]"
                        return
            # Stream ended without a "done" event; the deltas were already yielded.
            final_text = final_text or accumulated_text
            if not final_text:
                yield "[转录无结果]"
                return
            logger.info(f"🎯 Final transcription: {final_text}")
        except Exception as e:
            logger.error(f"❌ ASR transcription failed: {e}")
            yield f"[转录失败: {str(e)}]"
# One shared client for the whole module. The class keeps no per-instance
# state (the token is re-read on each request), so sharing it is safe.
_client: StepFunAPIClient = StepFunAPIClient()


def get_client() -> StepFunAPIClient:
    """Return the module-wide ``StepFunAPIClient`` instance."""
    return _client
def process_audio(audio_input: str, prompt_text: str, target_text: str,
                  edit_type: str, edit_info: str = None) -> str:
    """Upload reference audio, run an audio edit and save the result as WAV.

    Args:
        audio_input: Path of the reference audio file to upload.
        prompt_text: Text content of ``audio_input``.
        target_text: Text the generated audio should contain.
        edit_type: Edit mode passed to the API (clone, gender, age, ...).
        edit_info: Optional extra edit parameter such as "male" or "happy".

    Returns:
        Path of a new temporary ``.wav`` file. It is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: If no API token is configured.
        RuntimeError: If the upload fails, the file never becomes ready, or
            the edit call returns no audio.
    """
    api = get_client()
    if not api.token:
        raise ValueError("API token not configured. Please set STEPFUN_API_TOKEN environment variable.")

    # Step 1: upload the reference audio.
    logger.info("📤 Uploading audio file...")
    uploaded_id = api.upload_file(audio_input)
    if not uploaded_id:
        raise RuntimeError("Failed to upload audio file")

    # Step 2: block until the server reports the upload as processed.
    logger.info("⏳ Waiting for file to be ready...")
    is_ready = api.wait_for_file_ready(uploaded_id)
    if not is_ready:
        raise RuntimeError("File processing timeout")

    # Step 3: request the edited audio.
    logger.info("🎤 Calling audio edit API...")
    edited = api.audio_edit(
        file_id=uploaded_id,
        sample_text=prompt_text,
        target_text=target_text,
        edit_type=edit_type,
        edit_info=edit_info or "",
    )
    if not edited:
        raise RuntimeError("Audio edit API returned no data")

    # Step 4: persist to a temp file that survives closing the handle.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
        tmp_wav.write(edited)
    result_path = tmp_wav.name
    logger.info(f"✅ Audio saved to: {result_path}")
    return result_path
def transcribe_audio(audio_path: str, progress_callback=None, streaming=False):
    """Transcribe an audio file through the StepFun ASR SSE API.

    Args:
        audio_path: Path of the audio file.
        progress_callback: Optional callback invoked with the accumulated
            transcript as partial results arrive.
        streaming: If True, return a generator of incremental updates.

    Returns:
        streaming=False: the complete transcript (see ``transcribe_audio_sync``).
        streaming=True: a generator yielding incremental and final text
        (see ``transcribe_audio_streaming``).

    Raises:
        ValueError: If no API token is configured (streaming=False only).
        RuntimeError: If transcription fails (streaming=False only).
    """
    if not streaming:
        return transcribe_audio_sync(audio_path, progress_callback)
    return transcribe_audio_streaming(audio_path, progress_callback)
def transcribe_audio_sync(audio_path: str, progress_callback=None) -> str:
    """Transcribe ``audio_path`` and return the final text (blocking).

    Args:
        audio_path: Path of the audio file.
        progress_callback: Optional callback invoked with the accumulated
            transcript as partial results arrive.

    Returns:
        The complete transcript.

    Raises:
        ValueError: If no API token is configured.
        RuntimeError: If transcription fails. The underlying error is attached
            as ``__cause__``.
    """
    client = get_client()
    if not client.token:
        raise ValueError("API token not configured. Please set STEPFUN_API_TOKEN environment variable.")
    try:
        return client.transcribe_audio_sse(audio_path, progress_callback, streaming=False)
    except Exception as e:
        # Chain explicitly so tracebacks show the ASR failure as the direct cause.
        raise RuntimeError(f"Transcription failed: {e}") from e
def transcribe_audio_streaming(audio_path: str, progress_callback=None):
    """Stream a transcription of ``audio_path`` as a generator.

    Args:
        audio_path: Path of the audio file.
        progress_callback: Optional callback invoked with the accumulated
            transcript as partial results arrive.

    Yields:
        Incremental transcript text and the final text. Errors never raise:
        a missing token yields ``"[错误: API token not configured]"``, and a
        failure while streaming yields ``"[转录失败: ...]"``; the generator
        then ends.
    """
    api = get_client()
    if api.token:
        try:
            for chunk in api.transcribe_audio_sse(audio_path, progress_callback, streaming=True):
                yield chunk
        except Exception as err:
            yield f"[转录失败: {str(err)}]"
    else:
        yield "[错误: API token not configured]"