AI-Media-Studio / api /main.py
yiyang-8
🚀 极客基地升级: 完整的AI超级站点功能
6adb512
"""
AI Media Studio - FastAPI REST API
极客基地 AI全球超级站点
RESTful API服务,支持:
- AI图像生成
- AI视频生成
- AI音频生成
- 批量处理
- 任务队列
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from pathlib import Path
from datetime import datetime
import uuid
import os
import sys
import asyncio
from enum import Enum
# Add the project root (the parent of this file's directory) to the front of
# sys.path. Presumably this is what lets the lazily imported `ai_generation`
# package resolve when the API is started from inside this directory -- confirm.
ROOT_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT_DIR))
# ============================================
# 应用配置
# ============================================
# The FastAPI application object. The Markdown `description` is passed to
# FastAPI as-is, so its text must stay exactly as written.
# NOTE(review): the description advertises `X-API-Key` authentication, but no
# endpoint visible in this file checks that header -- confirm it is enforced
# elsewhere (dependency, middleware, or gateway).
app = FastAPI(
title="AI Media Studio API",
description="""
🚀 极客基地 AI全球超级站点 - RESTful API
## 功能
* 🎨 **图像生成** - Stable Diffusion, SDXL, FLUX
* 🎬 **视频生成** - AnimateDiff, SVD, CogVideoX
* 🎵 **音频生成** - Bark TTS, MusicGen, Whisper
* 📦 **批量处理** - 批量生成和处理任务
* 📊 **任务管理** - 异步任务队列
## 认证
API密钥通过Header传递: `X-API-Key: your-api-key`
""",
version="2.0.0",
docs_url="/docs",
redoc_url="/redoc",
)
# CORS configuration.
#
# A wildcard origin must not be combined with allow_credentials=True: with
# that pairing the middleware reflects whatever Origin the browser sends, so
# any website could issue credentialed (cookie / HTTP-auth) requests to this
# API. Allowed origins are therefore read from the CORS_ALLOW_ORIGINS
# environment variable (comma-separated, default "*"), and credentials are
# only enabled when every allowed origin is listed explicitly.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Directory where generated media and temporary uploads are written.
# mkdir(exist_ok=True) creates only the leaf directory; /tmp must already exist.
OUTPUT_DIR = Path("/tmp/ai_media_studio_output")
OUTPUT_DIR.mkdir(exist_ok=True)
# In-memory task registry keyed by task id. Being a plain module-level dict,
# it is emptied on every process restart and is private to this process.
tasks_db: Dict[str, Dict] = {}
# ============================================
# 数据模型
# ============================================
class TaskStatus(str, Enum):
    """Lifecycle states of a background task.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    # Created by create_task, not yet picked up by a worker.
    PENDING = "pending"
    # A background worker has started on the task.
    PROCESSING = "processing"
    # The worker finished and recorded a result.
    COMPLETED = "completed"
    # The worker raised; the error text is stored on the task.
    FAILED = "failed"
class ImageGenerationRequest(BaseModel):
    """Request body for image generation."""

    # Text conditioning.
    prompt: str = Field(..., description="提示词")
    negative_prompt: Optional[str] = Field("", description="负面提示词")

    # Model selection and output size (each side limited to 256..2048).
    model: str = Field("sdxl", description="模型名称")
    width: int = Field(1024, ge=256, le=2048, description="图像宽度")
    height: int = Field(1024, ge=256, le=2048, description="图像高度")

    # Sampling controls.
    num_inference_steps: int = Field(30, ge=1, le=100, description="推理步数")
    guidance_scale: float = Field(7.5, ge=0, le=20, description="引导系数")
    seed: Optional[int] = Field(None, description="随机种子")

    # Optional style preset; when set, the worker enhances the prompt with it.
    style: Optional[str] = Field(None, description="风格预设")
    num_images: int = Field(1, ge=1, le=4, description="生成数量")

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "prompt": "a beautiful sunset over the ocean",
                "model": "sdxl",
                "width": 1024,
                "height": 1024,
                "num_inference_steps": 30,
                "style": "cinematic",
            }
        }
class VideoGenerationRequest(BaseModel):
    """Request body for video generation."""

    # Text conditioning.
    prompt: str = Field(..., description="提示词")
    negative_prompt: Optional[str] = Field("", description="负面提示词")

    # Model selection, clip length and frame size.
    model: str = Field("animatediff", description="模型名称")
    num_frames: int = Field(16, ge=8, le=64, description="帧数")
    width: int = Field(512, ge=256, le=1024, description="视频宽度")
    height: int = Field(512, ge=256, le=1024, description="视频高度")

    # Sampling controls.
    num_inference_steps: int = Field(25, ge=10, le=50, description="推理步数")
    guidance_scale: float = Field(7.5, ge=0, le=15, description="引导系数")
    fps: int = Field(8, ge=4, le=30, description="帧率")
    seed: Optional[int] = Field(None, description="随机种子")

    # Optional motion preset; when set, the worker enhances the prompt with it.
    motion: Optional[str] = Field(None, description="运动类型")
class TTSRequest(BaseModel):
    """Request body for text-to-speech synthesis."""

    # Text to be spoken (required).
    text: str = Field(..., description="输入文本")
    # Voice preset id; the ids offered to clients are listed by /api/v1/audio/voices.
    voice: str = Field("en_female", description="语音预设")
    # Name of the TTS model handed to the worker.
    model: str = Field("bark", description="模型名称")
class MusicGenerationRequest(BaseModel):
    """Request body for music generation."""

    # Free-text description of the music to generate (required).
    prompt: str = Field(..., description="音乐描述")
    # Requested length in seconds, limited to 5..60.
    duration: float = Field(15, ge=5, le=60, description="时长(秒)")
    # Name of the music model handed to the worker.
    model: str = Field("musicgen-small", description="模型名称")
class STTRequest(BaseModel):
    """Options for speech-to-text transcription.

    NOTE(review): no endpoint visible in this file uses this model; the
    /api/v1/audio/stt route declares the same two options as plain function
    parameters instead.
    """

    # Optional language code; None leaves the choice to the transcriber.
    language: Optional[str] = Field(None, description="语言代码")
    # Name of the transcription model.
    model: str = Field("whisper-base", description="模型名称")
class BatchRequest(BaseModel):
    """Request body for batch image generation.

    Each prompt is turned into one ImageGenerationRequest, so width and height
    carry the same 256..2048 bounds as that model. Without them an out-of-range
    size passes request validation and only fails later, inside the endpoint,
    when the per-prompt request is constructed.
    """

    # One image-generation task is created per entry.
    prompts: List[str] = Field(..., description="提示词列表")
    # Model name shared by every task in the batch.
    model: str = Field("sdxl", description="模型名称")
    # Output size shared by every task; bounds mirror ImageGenerationRequest.
    width: int = Field(1024, ge=256, le=2048, description="图像宽度")
    height: int = Field(1024, ge=256, le=2048, description="图像高度")
class TaskResponse(BaseModel):
    """Response returned when a background task has been created."""

    # Identifier to poll via /api/v1/tasks/{task_id}.
    task_id: str
    # Status at the time the response was built.
    status: TaskStatus
    # Human-readable confirmation message.
    message: str
    # ISO-8601 creation timestamp copied from the task record.
    created_at: str
    # Optional download URL; the creation endpoints in this file leave it unset.
    result_url: Optional[str] = None
# ============================================
# 辅助函数
# ============================================
def create_task(task_type: str) -> str:
    """Register a new PENDING task in ``tasks_db`` and return its id.

    Args:
        task_type: Label stored under the task's "type" key
            (e.g. "image_generation").

    Returns:
        The generated UUID4 string that identifies the task.
    """
    task_id = str(uuid.uuid4())
    # Take a single timestamp so a brand-new task has created_at == updated_at;
    # two separate now() calls made the task look already updated.
    now_iso = datetime.now().isoformat()
    tasks_db[task_id] = {
        "id": task_id,
        "type": task_type,
        "status": TaskStatus.PENDING,
        "created_at": now_iso,
        "updated_at": now_iso,
        "result": None,
        "error": None,
    }
    return task_id
def update_task(task_id: str, status: TaskStatus, result: Any = None, error: str = None):
    """Set a task's status and optionally record its result or error.

    Unknown task ids are ignored silently (for example when the task was
    deleted while its worker was still running). ``result`` and ``error``
    are only stored when truthy, so an existing value is never cleared.
    """
    record = tasks_db.get(task_id)
    if record is None:
        return

    record["status"] = status
    record["updated_at"] = datetime.now().isoformat()
    if result:
        record["result"] = result
    if error:
        record["error"] = error
# ============================================
# 后台任务
# ============================================
async def process_image_generation(task_id: str, request: ImageGenerationRequest):
    """Run one image-generation task in the background and record the outcome.

    The task is marked PROCESSING, the image is generated to
    ``OUTPUT_DIR/<task_id>.png``, and the task is then marked COMPLETED with
    that path, or FAILED with the exception text.

    Model loading and inference are blocking calls. Because this is a
    coroutine, running them inline would stall the event loop (and every
    other request) for the whole generation, so they run in the default
    thread-pool executor instead.

    Args:
        task_id: Id previously returned by create_task.
        request: Validated generation parameters.
    """

    def _generate() -> str:
        # Deferred import, kept from the original: the generation package is
        # only imported when a job actually runs.
        from ai_generation import ImageGenerator, PromptEnhancer

        prompt = request.prompt
        negative_prompt = request.negative_prompt
        if request.style:
            prompt = PromptEnhancer.enhance_prompt(prompt, style=request.style)
        # NOTE(review): the default negative prompt is applied whether or not
        # a style was given -- confirm this matches the intended behavior.
        if not negative_prompt:
            negative_prompt = PromptEnhancer.get_negative_prompt()

        generator = ImageGenerator(model_name=request.model)
        generator.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.png"
        generator.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            seed=request.seed,
            num_images=request.num_images,
            output_path=str(output_path),
        )
        return str(output_path)

    try:
        update_task(task_id, TaskStatus.PROCESSING)
        loop = asyncio.get_running_loop()
        result_path = await loop.run_in_executor(None, _generate)
        update_task(task_id, TaskStatus.COMPLETED, result=result_path)
    except Exception as e:
        # Background-task boundary: any failure is recorded on the task so
        # the client can read it from the task status endpoint.
        update_task(task_id, TaskStatus.FAILED, error=str(e))
async def process_video_generation(task_id: str, request: VideoGenerationRequest):
    """Run one video-generation task in the background and record the outcome.

    The task is marked PROCESSING, the video is generated to
    ``OUTPUT_DIR/<task_id>.mp4``, and the task is then marked COMPLETED with
    that path, or FAILED with the exception text.

    Model loading and inference are blocking calls, so they run in the
    default thread-pool executor rather than inline in this coroutine, which
    would stall the event loop for the whole generation.

    Args:
        task_id: Id previously returned by create_task.
        request: Validated generation parameters.
    """

    def _generate() -> str:
        # Deferred import, kept from the original: the generation package is
        # only imported when a job actually runs.
        from ai_generation import VideoGenerator, VideoPromptEnhancer

        prompt = request.prompt
        negative_prompt = request.negative_prompt
        if request.motion:
            prompt = VideoPromptEnhancer.enhance_prompt(prompt, motion=request.motion)
        # NOTE(review): the default negative prompt is applied whether or not
        # a motion preset was given -- confirm this matches the intended behavior.
        if not negative_prompt:
            negative_prompt = VideoPromptEnhancer.get_negative_prompt()

        generator = VideoGenerator(model_name=request.model)
        generator.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.mp4"
        generator.generate_from_text(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_frames=request.num_frames,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            fps=request.fps,
            seed=request.seed,
            output_path=str(output_path),
        )
        return str(output_path)

    try:
        update_task(task_id, TaskStatus.PROCESSING)
        loop = asyncio.get_running_loop()
        result_path = await loop.run_in_executor(None, _generate)
        update_task(task_id, TaskStatus.COMPLETED, result=result_path)
    except Exception as e:
        # Background-task boundary: any failure is recorded on the task so
        # the client can read it from the task status endpoint.
        update_task(task_id, TaskStatus.FAILED, error=str(e))
async def process_tts(task_id: str, request: TTSRequest):
    """Run one text-to-speech task in the background and record the outcome.

    The task is marked PROCESSING, the audio is synthesized to
    ``OUTPUT_DIR/<task_id>.wav``, and the task is then marked COMPLETED with
    that path, or FAILED with the exception text.

    Model loading and synthesis are blocking calls, so they run in the
    default thread-pool executor rather than inline in this coroutine, which
    would stall the event loop for the whole synthesis.

    Args:
        task_id: Id previously returned by create_task.
        request: Validated synthesis parameters.
    """

    def _synthesize() -> str:
        # Deferred import, kept from the original: the generation package is
        # only imported when a job actually runs.
        from ai_generation import TextToSpeech

        tts = TextToSpeech(model_name=request.model)
        tts.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.wav"
        # The returned pair is not used here; the unpacking is kept so an
        # unexpected return shape still fails the task as it did before.
        _audio, _sr = tts.synthesize(
            request.text,
            voice=request.voice,
            output_path=str(output_path),
        )
        return str(output_path)

    try:
        update_task(task_id, TaskStatus.PROCESSING)
        loop = asyncio.get_running_loop()
        result_path = await loop.run_in_executor(None, _synthesize)
        update_task(task_id, TaskStatus.COMPLETED, result=result_path)
    except Exception as e:
        # Background-task boundary: any failure is recorded on the task so
        # the client can read it from the task status endpoint.
        update_task(task_id, TaskStatus.FAILED, error=str(e))
async def process_music_generation(task_id: str, request: MusicGenerationRequest):
    """Run one music-generation task in the background and record the outcome.

    The task is marked PROCESSING, the audio is generated to
    ``OUTPUT_DIR/<task_id>.wav``, and the task is then marked COMPLETED with
    that path, or FAILED with the exception text.

    Model loading and generation are blocking calls, so they run in the
    default thread-pool executor rather than inline in this coroutine, which
    would stall the event loop for the whole generation.

    Args:
        task_id: Id previously returned by create_task.
        request: Validated generation parameters.
    """

    def _generate() -> str:
        # Deferred import, kept from the original: the generation package is
        # only imported when a job actually runs.
        from ai_generation import MusicGenerator

        generator = MusicGenerator(model_name=request.model)
        generator.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.wav"
        # The returned pair is not used here; the unpacking is kept so an
        # unexpected return shape still fails the task as it did before.
        _audio, _sr = generator.generate(
            request.prompt,
            duration=request.duration,
            output_path=str(output_path),
        )
        return str(output_path)

    try:
        update_task(task_id, TaskStatus.PROCESSING)
        loop = asyncio.get_running_loop()
        result_path = await loop.run_in_executor(None, _generate)
        update_task(task_id, TaskStatus.COMPLETED, result=result_path)
    except Exception as e:
        # Background-task boundary: any failure is recorded on the task so
        # the client can read it from the task status endpoint.
        update_task(task_id, TaskStatus.FAILED, error=str(e))
# ============================================
# API端点
# ============================================
@app.get("/")
async def root():
    """Describe the service and list its main endpoints."""
    endpoint_map = {
        "image": "/api/v1/image/generate",
        "video": "/api/v1/video/generate",
        "audio": "/api/v1/audio/tts",
        "music": "/api/v1/audio/music",
        "tasks": "/api/v1/tasks",
    }
    service_info = {
        "name": "AI Media Studio API",
        "version": "2.0.0",
        "description": "极客基地 AI全球超级站点",
        "docs": "/docs",
    }
    service_info["endpoints"] = endpoint_map
    return service_info
@app.get("/health")
async def health_check():
    """Liveness probe: report a fixed healthy status plus the current local time."""
    now_iso = datetime.now().isoformat()
    return dict(status="healthy", timestamp=now_iso)
# ============================================
# 图像生成API
# ============================================
@app.post("/api/v1/image/generate", response_model=TaskResponse, tags=["Image"])
async def generate_image(request: ImageGenerationRequest, background_tasks: BackgroundTasks):
    """
    Generate an AI image.

    Supported models:
    - sdxl: Stable Diffusion XL
    - sdxl-turbo: SDXL Turbo (fast)
    - sd-1.5: Stable Diffusion 1.5
    - sd3: Stable Diffusion 3
    - flux-schnell: FLUX Schnell

    The work is queued as a background task; the response only carries the
    task id to poll.
    """
    new_id = create_task("image_generation")
    background_tasks.add_task(process_image_generation, new_id, request)

    record = tasks_db[new_id]
    return TaskResponse(
        task_id=new_id,
        status=TaskStatus.PENDING,
        message="图像生成任务已创建",
        created_at=record["created_at"],
    )
@app.get("/api/v1/image/models", tags=["Image"])
async def list_image_models():
    """List the image-generation models advertised by the API."""
    # (id, display name, description) for each advertised model.
    catalog = (
        ("sd-1.5", "Stable Diffusion 1.5", "经典SD模型"),
        ("sd-2.1", "Stable Diffusion 2.1", "改进版SD"),
        ("sdxl", "Stable Diffusion XL", "高质量1024px"),
        ("sdxl-turbo", "SDXL Turbo", "快速生成"),
        ("sd3", "Stable Diffusion 3", "最新SD3"),
        ("flux-schnell", "FLUX Schnell", "快速高质量"),
        ("flux-dev", "FLUX Dev", "开发版"),
        ("playground-v2.5", "Playground v2.5", "美学优化"),
    )
    return {
        "models": [
            {"id": model_id, "name": label, "description": blurb}
            for model_id, label, blurb in catalog
        ]
    }
@app.get("/api/v1/image/styles", tags=["Image"])
async def list_image_styles():
    """List the style presets advertised by the API."""
    # (id, display name) for each advertised preset.
    presets = (
        ("cinematic", "电影感"),
        ("anime", "动漫"),
        ("photorealistic", "写实"),
        ("oil_painting", "油画"),
        ("watercolor", "水彩"),
        ("3d_render", "3D渲染"),
        ("pixel_art", "像素艺术"),
        ("concept_art", "概念艺术"),
        ("fantasy", "奇幻"),
        ("sci_fi", "科幻"),
    )
    return {
        "styles": [{"id": style_id, "name": label} for style_id, label in presets]
    }
# ============================================
# 视频生成API
# ============================================
@app.post("/api/v1/video/generate", response_model=TaskResponse, tags=["Video"])
async def generate_video(request: VideoGenerationRequest, background_tasks: BackgroundTasks):
    """
    Generate an AI video.

    Supported models:
    - animatediff: AnimateDiff
    - svd: Stable Video Diffusion
    - cogvideox-2b: CogVideoX 2B

    The work is queued as a background task; the response only carries the
    task id to poll.
    """
    new_id = create_task("video_generation")
    background_tasks.add_task(process_video_generation, new_id, request)

    record = tasks_db[new_id]
    return TaskResponse(
        task_id=new_id,
        status=TaskStatus.PENDING,
        message="视频生成任务已创建",
        created_at=record["created_at"],
    )
@app.get("/api/v1/video/models", tags=["Video"])
async def list_video_models():
    """List the video-generation models advertised by the API."""
    # (id, display name, description) for each advertised model.
    catalog = (
        ("animatediff", "AnimateDiff", "动画生成"),
        ("animatediff-v3", "AnimateDiff v3", "改进版"),
        ("svd", "Stable Video Diffusion", "图生视频"),
        ("svd-xt", "SVD-XT", "更长视频"),
        ("cogvideox-2b", "CogVideoX 2B", "文生视频"),
        ("cogvideox-5b", "CogVideoX 5B", "高质量"),
    )
    return {
        "models": [
            {"id": model_id, "name": label, "description": blurb}
            for model_id, label, blurb in catalog
        ]
    }
# ============================================
# 音频生成API
# ============================================
@app.post("/api/v1/audio/tts", response_model=TaskResponse, tags=["Audio"])
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """
    Convert text to speech.

    Supported voices:
    - en_female/en_male: English
    - zh_female/zh_male: Chinese
    - ja_female/ja_male: Japanese

    The work is queued as a background task; the response only carries the
    task id to poll.
    """
    new_id = create_task("tts")
    background_tasks.add_task(process_tts, new_id, request)

    record = tasks_db[new_id]
    return TaskResponse(
        task_id=new_id,
        status=TaskStatus.PENDING,
        message="TTS任务已创建",
        created_at=record["created_at"],
    )
@app.post("/api/v1/audio/music", response_model=TaskResponse, tags=["Audio"])
async def generate_music(request: MusicGenerationRequest, background_tasks: BackgroundTasks):
    """
    Generate AI music.

    Supported models:
    - musicgen-small/medium/large: Meta MusicGen
    - audioldm2-music: AudioLDM2

    The work is queued as a background task; the response only carries the
    task id to poll.
    """
    new_id = create_task("music_generation")
    background_tasks.add_task(process_music_generation, new_id, request)

    record = tasks_db[new_id]
    return TaskResponse(
        task_id=new_id,
        status=TaskStatus.PENDING,
        message="音乐生成任务已创建",
        created_at=record["created_at"],
    )
@app.post("/api/v1/audio/stt", tags=["Audio"])
async def speech_to_text(
    file: UploadFile = File(...),
    language: Optional[str] = None,
    model: str = "whisper-base",
):
    """
    Transcribe an uploaded audio file to text.

    Supported models:
    - whisper-tiny/base/small/medium/large

    The upload is written to a uniquely named temporary file in OUTPUT_DIR,
    transcribed, and the temporary file is removed again whether or not
    transcription succeeds.

    Returns:
        {"text": <transcript>, "language": <requested language or "auto">}

    Raises:
        HTTPException: 500 with the error text if saving or transcription fails.
    """
    # Only the extension of the client-supplied name is reused; the filename
    # may be missing, in which case the temp file simply has no suffix.
    suffix = Path(file.filename or "").suffix
    file_path = OUTPUT_DIR / f"stt_{uuid.uuid4()}{suffix}"

    def _transcribe():
        # Deferred import, kept from the original: the generation package is
        # only imported when a transcription actually runs.
        from ai_generation import SpeechToText

        stt = SpeechToText(model_name=model)
        stt.load_model()
        return stt.transcribe(str(file_path), language=language)

    try:
        content = await file.read()
        with open(file_path, "wb") as f:
            f.write(content)

        # Model loading and transcription block, so run them in the default
        # thread pool instead of stalling the event loop.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _transcribe)
        return {"text": result["text"], "language": language or "auto"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temp upload; previously it was leaked whenever
        # transcription raised.
        if file_path.exists():
            file_path.unlink()
@app.get("/api/v1/audio/voices", tags=["Audio"])
async def list_voices():
    """List the voice presets advertised by the API."""
    # (id, display name, language code) for each advertised voice.
    presets = (
        ("en_female", "英语女声", "en"),
        ("en_male", "英语男声", "en"),
        ("zh_female", "中文女声", "zh"),
        ("zh_male", "中文男声", "zh"),
        ("ja_female", "日语女声", "ja"),
        ("ja_male", "日语男声", "ja"),
        ("ko_female", "韩语女声", "ko"),
        ("ko_male", "韩语男声", "ko"),
    )
    return {
        "voices": [
            {"id": voice_id, "name": label, "language": lang_code}
            for voice_id, label, lang_code in presets
        ]
    }
# ============================================
# 任务管理API
# ============================================
@app.get("/api/v1/tasks/{task_id}", tags=["Tasks"])
async def get_task(task_id: str):
    """Return the status of one task.

    The response gains a ``result_url`` once the task is COMPLETED with a
    stored result, and an ``error`` whenever an error text was recorded.

    Raises:
        HTTPException: 404 if the task id is unknown.
    """
    record = tasks_db.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    payload = {"task_id": record["id"]}
    for key in ("type", "status", "created_at", "updated_at"):
        payload[key] = record[key]

    finished = record["status"] == TaskStatus.COMPLETED
    if finished and record["result"]:
        payload["result_url"] = f"/api/v1/tasks/{task_id}/download"
    if record["error"]:
        payload["error"] = record["error"]
    return payload
@app.get("/api/v1/tasks/{task_id}/download", tags=["Tasks"])
async def download_task_result(task_id: str):
    """Send the result file of a completed task.

    Raises:
        HTTPException: 404 if the task id is unknown, 400 if the task has not
            completed, 404 if no result was recorded or the file is missing.
    """
    record = tasks_db.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    if record["status"] != TaskStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="任务未完成")

    artifact = record["result"]
    if not (artifact and Path(artifact).exists()):
        raise HTTPException(status_code=404, detail="结果文件不存在")

    # The download is offered under the stored file's own name.
    return FileResponse(artifact, filename=Path(artifact).name)
@app.get("/api/v1/tasks", tags=["Tasks"])
async def list_tasks(
    status: Optional[TaskStatus] = None,
    limit: int = 20,
):
    """List tasks, newest first.

    Args:
        status: When given, only tasks in this status are returned.
        limit: Maximum number of tasks to return. Negative values are treated
            as 0; previously a negative limit sliced from the end of the list
            and silently dropped the oldest tasks instead of limiting.

    Returns:
        {"tasks": [...], "total": N}, where ``total`` counts every task that
        matches the filter, not just the returned page. The task entries are
        the raw records from ``tasks_db``.
    """
    matching = [
        task for task in tasks_db.values()
        if status is None or task["status"] == status
    ]
    # created_at is an ISO-8601 string, so plain string ordering is chronological.
    matching.sort(key=lambda task: task["created_at"], reverse=True)
    return {"tasks": matching[:max(limit, 0)], "total": len(matching)}
@app.delete("/api/v1/tasks/{task_id}", tags=["Tasks"])
async def delete_task(task_id: str):
    """Delete a task record together with its result file, if one exists.

    Raises:
        HTTPException: 404 if the task id is unknown.
    """
    record = tasks_db.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Remove the result file first, then drop the record itself.
    artifact = record["result"]
    if artifact:
        artifact_path = Path(artifact)
        if artifact_path.exists():
            artifact_path.unlink()

    tasks_db.pop(task_id)
    return {"message": "任务已删除"}
# ============================================
# 批量处理API
# ============================================
@app.post("/api/v1/batch/images", tags=["Batch"])
async def batch_generate_images(request: BatchRequest, background_tasks: BackgroundTasks):
    """
    Generate images in batch.

    One image-generation task is created per submitted prompt; the response
    lists the task ids so each can be polled individually.

    Every per-prompt request is validated before any task is created.
    Previously validation happened inside the loop, after the task had been
    registered, so an invalid width or height left an orphan PENDING task and
    surfaced as a 500.

    Raises:
        HTTPException: 422 if the shared parameters do not form a valid
            ImageGenerationRequest.
    """
    # Local import: the file's top-level pydantic import only brings in
    # BaseModel and Field.
    from pydantic import ValidationError

    try:
        image_requests = [
            ImageGenerationRequest(
                prompt=prompt,
                model=request.model,
                width=request.width,
                height=request.height,
            )
            for prompt in request.prompts
        ]
    except ValidationError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    # NOTE(review): batch_id is only returned to the client; nothing visible
    # in this file stores it, so tasks cannot be looked up by batch.
    batch_id = str(uuid.uuid4())
    task_ids = []
    for img_request in image_requests:
        task_id = create_task("image_generation")
        task_ids.append(task_id)
        background_tasks.add_task(process_image_generation, task_id, img_request)

    return {
        "batch_id": batch_id,
        "task_ids": task_ids,
        "total": len(task_ids),
        "message": f"已创建 {len(task_ids)} 个图像生成任务",
    }
# ============================================
# 启动服务
# ============================================
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000, with auto-reload enabled. The "main:app" import string means
    # uvicorn has to be able to import this module as `main`.
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)