Spaces:
Runtime error
Runtime error
| """ | |
| AI Media Studio - FastAPI REST API | |
| 极客基地 AI全球超级站点 | |
| RESTful API服务,支持: | |
| - AI图像生成 | |
| - AI视频生成 | |
| - AI音频生成 | |
| - 批量处理 | |
| - 任务队列 | |
| """ | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from pydantic import BaseModel, Field | |
| from typing import Optional, List, Dict, Any | |
| from pathlib import Path | |
| from datetime import datetime | |
| import uuid | |
| import os | |
| import sys | |
| import asyncio | |
| from enum import Enum | |
# Put the project root on the import path so project packages can be imported.
# The root is taken to be the parent of the directory that contains this file.
ROOT_DIR = Path(__file__).parent.parent
_root_dir_str = str(ROOT_DIR)
sys.path.insert(0, _root_dir_str)
| # ============================================ | |
| # 应用配置 | |
| # ============================================ | |
# The FastAPI application object; the description below is Markdown that is
# passed to FastAPI as the API description.
# NOTE(review): the description advertises an X-API-Key header, but no key
# check is visible in this view of the file — confirm where it is enforced.
app = FastAPI(
    title="AI Media Studio API",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    description="""
🚀 极客基地 AI全球超级站点 - RESTful API
## 功能
* 🎨 **图像生成** - Stable Diffusion, SDXL, FLUX
* 🎬 **视频生成** - AnimateDiff, SVD, CogVideoX
* 🎵 **音频生成** - Bark TTS, MusicGen, Whisper
* 📦 **批量处理** - 批量生成和处理任务
* 📊 **任务管理** - 异步任务队列
## 认证
API密钥通过Header传递: `X-API-Key: your-api-key`
""",
)
# CORS configuration.
# A wildcard origin must not be combined with credentialed requests: the CORS
# protocol forbids it and it would let any site make credentialed calls.
# Explicit origins can be supplied through the comma-separated
# CORS_ALLOW_ORIGINS environment variable; credentials are enabled only when
# that list is non-empty. With the variable unset the API stays open to every
# origin, but without credentials.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins or ["*"],
    allow_credentials=bool(_cors_origins),
    allow_methods=["*"],
    allow_headers=["*"],
)
# Directory where generated media and temporary uploads are written.
OUTPUT_DIR = Path("/tmp/ai_media_studio_output")
# parents=True keeps start-up from failing when the parent directory is missing.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# In-memory task store: task id -> task record (see create_task for the keys).
tasks_db: Dict[str, Dict] = {}
| # ============================================ | |
| # 数据模型 | |
| # ============================================ | |
class TaskStatus(str, Enum):
    """Lifecycle states of a task; members are also plain strings."""

    # Set by create_task when the task record is first written.
    PENDING = "pending"
    # Set by the background workers when they start.
    PROCESSING = "processing"
    # Set by the background workers after the output path is recorded.
    COMPLETED = "completed"
    # Set by the background workers when an exception is caught.
    FAILED = "failed"
class ImageGenerationRequest(BaseModel):
    """Request body for an image generation task."""

    # Required prompt; every other field has a default.
    prompt: str = Field(..., description="提示词")
    negative_prompt: Optional[str] = Field(default="", description="负面提示词")
    model: str = Field(default="sdxl", description="模型名称")
    # Width and height are each validated to the range 256..2048.
    width: int = Field(default=1024, ge=256, le=2048, description="图像宽度")
    height: int = Field(default=1024, ge=256, le=2048, description="图像高度")
    num_inference_steps: int = Field(default=30, ge=1, le=100, description="推理步数")
    guidance_scale: float = Field(default=7.5, ge=0, le=20, description="引导系数")
    seed: Optional[int] = Field(default=None, description="随机种子")
    style: Optional[str] = Field(default=None, description="风格预设")
    # Between 1 and 4 images per request.
    num_images: int = Field(default=1, ge=1, le=4, description="生成数量")

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "prompt": "a beautiful sunset over the ocean",
                "model": "sdxl",
                "width": 1024,
                "height": 1024,
                "num_inference_steps": 30,
                "style": "cinematic",
            }
        }
class VideoGenerationRequest(BaseModel):
    """Request body for a video generation task."""

    # Required prompt; every other field has a default.
    prompt: str = Field(..., description="提示词")
    negative_prompt: Optional[str] = Field(default="", description="负面提示词")
    model: str = Field(default="animatediff", description="模型名称")
    # Frame count is validated to 8..64.
    num_frames: int = Field(default=16, ge=8, le=64, description="帧数")
    # Width and height are each validated to 256..1024.
    width: int = Field(default=512, ge=256, le=1024, description="视频宽度")
    height: int = Field(default=512, ge=256, le=1024, description="视频高度")
    num_inference_steps: int = Field(default=25, ge=10, le=50, description="推理步数")
    guidance_scale: float = Field(default=7.5, ge=0, le=15, description="引导系数")
    fps: int = Field(default=8, ge=4, le=30, description="帧率")
    seed: Optional[int] = Field(default=None, description="随机种子")
    motion: Optional[str] = Field(default=None, description="运动类型")
class TTSRequest(BaseModel):
    """Request body for a text-to-speech task."""

    # Required input text.
    text: str = Field(..., description="输入文本")
    voice: str = Field(default="en_female", description="语音预设")
    model: str = Field(default="bark", description="模型名称")
class MusicGenerationRequest(BaseModel):
    """Request body for a music generation task."""

    # Required description of the music to generate.
    prompt: str = Field(..., description="音乐描述")
    # Duration is validated to the range 5..60 (seconds, per the description).
    duration: float = Field(default=15, ge=5, le=60, description="时长(秒)")
    model: str = Field(default="musicgen-small", description="模型名称")
class STTRequest(BaseModel):
    """Options for a speech-to-text request.

    NOTE(review): nothing in the visible part of this file references this
    model — confirm whether it is used elsewhere.
    """

    language: Optional[str] = Field(default=None, description="语言代码")
    model: str = Field(default="whisper-base", description="模型名称")
class BatchRequest(BaseModel):
    """Request body for batch image generation: one task per prompt."""

    # Each prompt becomes its own image generation task.
    prompts: List[str] = Field(..., description="提示词列表")
    model: str = Field("sdxl", description="模型名称")
    # Same 256..2048 bounds as ImageGenerationRequest, which is built from
    # these values; enforcing them here rejects bad sizes at request
    # validation instead of failing later inside the handler.
    width: int = Field(1024, ge=256, le=2048, description="图像宽度")
    height: int = Field(1024, ge=256, le=2048, description="图像高度")
class TaskResponse(BaseModel):
    """Response returned when a task has been created."""

    # Identifier produced by create_task.
    task_id: str
    status: TaskStatus
    # Human-readable confirmation message.
    message: str
    # Creation timestamp, copied from the task record.
    created_at: str
    # Optional; the handlers in this file leave it at the default.
    result_url: Optional[str] = None
| # ============================================ | |
| # 辅助函数 | |
| # ============================================ | |
def create_task(task_type: str) -> str:
    """Register a new pending task in ``tasks_db`` and return its id.

    Args:
        task_type: Label stored under the record's ``"type"`` key.

    Returns:
        The generated UUID4 string that keys the new record.
    """
    task_id = str(uuid.uuid4())
    # Take the timestamp once so created_at and updated_at are identical for
    # a brand-new task.
    timestamp = datetime.now().isoformat()
    tasks_db[task_id] = {
        "id": task_id,
        "type": task_type,
        "status": TaskStatus.PENDING,
        "created_at": timestamp,
        "updated_at": timestamp,
        "result": None,
        "error": None,
    }
    return task_id
def update_task(task_id: str, status: TaskStatus, result: Any = None, error: Optional[str] = None):
    """Update the status of a stored task, optionally recording result or error.

    Unknown task ids are ignored silently, so a worker that finishes after its
    task was deleted does not raise.

    Args:
        task_id: Key of the record in ``tasks_db``.
        status: New status to store.
        result: Stored under ``"result"`` when not ``None``.
        error: Stored under ``"error"`` when not ``None``.
    """
    task = tasks_db.get(task_id)
    if task is None:
        return
    task["status"] = status
    task["updated_at"] = datetime.now().isoformat()
    # Compare against None rather than truthiness so that falsy but
    # meaningful values (for example an empty error string) are still kept.
    if result is not None:
        task["result"] = result
    if error is not None:
        task["error"] = error
| # ============================================ | |
| # 后台任务 | |
| # ============================================ | |
async def process_image_generation(task_id: str, request: ImageGenerationRequest):
    """Run one image generation task and record the outcome in the task store.

    The task is marked PROCESSING, the image is written to
    ``OUTPUT_DIR/<task_id>.png``, and the task ends as COMPLETED (with that
    path as its result) or FAILED (with the error text).

    NOTE(review): the generator calls below are made synchronously inside an
    ``async`` function, so they presumably hold the event loop while running —
    confirm and consider moving them to a worker.

    Args:
        task_id: Id of the task record to update.
        request: Validated generation parameters.
    """
    try:
        update_task(task_id, TaskStatus.PROCESSING)
        # Imported lazily, as in the original.
        from ai_generation import ImageGenerator, PromptEnhancer

        # Apply the style preset to the prompt when one was requested.
        prompt = request.prompt
        negative_prompt = request.negative_prompt
        if request.style:
            prompt = PromptEnhancer.enhance_prompt(prompt, style=request.style)
        # Fall back to the enhancer's negative prompt when none was supplied.
        if not negative_prompt:
            negative_prompt = PromptEnhancer.get_negative_prompt()

        generator = ImageGenerator(model_name=request.model)
        generator.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.png"
        # The return value is not needed; the output is consumed via output_path.
        generator.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            seed=request.seed,
            num_images=request.num_images,
            output_path=str(output_path),
        )
        update_task(task_id, TaskStatus.COMPLETED, result=str(output_path))
    except Exception as e:
        # Some exceptions have an empty message; fall back to repr() so a
        # failed task always carries a diagnosable error.
        update_task(task_id, TaskStatus.FAILED, error=str(e) or repr(e))
async def process_video_generation(task_id: str, request: VideoGenerationRequest):
    """Run one video generation task and record the outcome in the task store.

    The task is marked PROCESSING, the video is written to
    ``OUTPUT_DIR/<task_id>.mp4``, and the task ends as COMPLETED (with that
    path as its result) or FAILED (with the error text).

    NOTE(review): the generator calls below are made synchronously inside an
    ``async`` function, so they presumably hold the event loop while running —
    confirm and consider moving them to a worker.

    Args:
        task_id: Id of the task record to update.
        request: Validated generation parameters.
    """
    try:
        update_task(task_id, TaskStatus.PROCESSING)
        # Imported lazily, as in the original.
        from ai_generation import VideoGenerator, VideoPromptEnhancer

        # Apply the motion preset to the prompt when one was requested.
        prompt = request.prompt
        negative_prompt = request.negative_prompt
        if request.motion:
            prompt = VideoPromptEnhancer.enhance_prompt(prompt, motion=request.motion)
        # Fall back to the enhancer's negative prompt when none was supplied.
        if not negative_prompt:
            negative_prompt = VideoPromptEnhancer.get_negative_prompt()

        generator = VideoGenerator(model_name=request.model)
        generator.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.mp4"
        # The return value is not needed; the output is consumed via output_path.
        generator.generate_from_text(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_frames=request.num_frames,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            fps=request.fps,
            seed=request.seed,
            output_path=str(output_path),
        )
        update_task(task_id, TaskStatus.COMPLETED, result=str(output_path))
    except Exception as e:
        # Some exceptions have an empty message; fall back to repr() so a
        # failed task always carries a diagnosable error.
        update_task(task_id, TaskStatus.FAILED, error=str(e) or repr(e))
async def process_tts(task_id: str, request: TTSRequest):
    """Run one text-to-speech task and record the outcome in the task store.

    The task is marked PROCESSING, the audio is written to
    ``OUTPUT_DIR/<task_id>.wav``, and the task ends as COMPLETED (with that
    path as its result) or FAILED (with the error text).

    NOTE(review): the synthesis calls below are made synchronously inside an
    ``async`` function, so they presumably hold the event loop while running —
    confirm and consider moving them to a worker.

    Args:
        task_id: Id of the task record to update.
        request: Validated TTS parameters.
    """
    try:
        update_task(task_id, TaskStatus.PROCESSING)
        # Imported lazily, as in the original.
        from ai_generation import TextToSpeech

        tts = TextToSpeech(model_name=request.model)
        tts.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.wav"
        # The return value is not needed; the output is consumed via output_path.
        tts.synthesize(
            request.text,
            voice=request.voice,
            output_path=str(output_path),
        )
        update_task(task_id, TaskStatus.COMPLETED, result=str(output_path))
    except Exception as e:
        # Some exceptions have an empty message; fall back to repr() so a
        # failed task always carries a diagnosable error.
        update_task(task_id, TaskStatus.FAILED, error=str(e) or repr(e))
async def process_music_generation(task_id: str, request: MusicGenerationRequest):
    """Run one music generation task and record the outcome in the task store.

    The task is marked PROCESSING, the audio is written to
    ``OUTPUT_DIR/<task_id>.wav``, and the task ends as COMPLETED (with that
    path as its result) or FAILED (with the error text).

    NOTE(review): the generator calls below are made synchronously inside an
    ``async`` function, so they presumably hold the event loop while running —
    confirm and consider moving them to a worker.

    Args:
        task_id: Id of the task record to update.
        request: Validated generation parameters.
    """
    try:
        update_task(task_id, TaskStatus.PROCESSING)
        # Imported lazily, as in the original.
        from ai_generation import MusicGenerator

        generator = MusicGenerator(model_name=request.model)
        generator.load_model()
        output_path = OUTPUT_DIR / f"{task_id}.wav"
        # The return value is not needed; the output is consumed via output_path.
        generator.generate(
            request.prompt,
            duration=request.duration,
            output_path=str(output_path),
        )
        update_task(task_id, TaskStatus.COMPLETED, result=str(output_path))
    except Exception as e:
        # Some exceptions have an empty message; fall back to repr() so a
        # failed task always carries a diagnosable error.
        update_task(task_id, TaskStatus.FAILED, error=str(e) or repr(e))
| # ============================================ | |
| # API端点 | |
| # ============================================ | |
async def root():
    """Return the API name, version, description and main endpoint paths."""
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file — confirm how it is registered on the app.
    endpoint_paths = {
        "image": "/api/v1/image/generate",
        "video": "/api/v1/video/generate",
        "audio": "/api/v1/audio/tts",
        "music": "/api/v1/audio/music",
        "tasks": "/api/v1/tasks",
    }
    info = {
        "name": "AI Media Studio API",
        "version": "2.0.0",
        "description": "极客基地 AI全球超级站点",
        "docs": "/docs",
    }
    info["endpoints"] = endpoint_paths
    return info
async def health_check():
    """Report a fixed "healthy" status together with the current local time."""
    checked_at = datetime.now()
    return {
        "status": "healthy",
        "timestamp": checked_at.isoformat(),
    }
| # ============================================ | |
| # 图像生成API | |
| # ============================================ | |
async def generate_image(request: ImageGenerationRequest, background_tasks: BackgroundTasks):
    """
    Create an AI image generation task.

    Supported models:
    - sdxl: Stable Diffusion XL
    - sdxl-turbo: SDXL Turbo (fast)
    - sd-1.5: Stable Diffusion 1.5
    - sd3: Stable Diffusion 3
    - flux-schnell: FLUX Schnell
    """
    # Register the task first, then hand the work to the background runner.
    new_task_id = create_task("image_generation")
    background_tasks.add_task(process_image_generation, new_task_id, request)
    created_at = tasks_db[new_task_id]["created_at"]
    return TaskResponse(
        task_id=new_task_id,
        status=TaskStatus.PENDING,
        message="图像生成任务已创建",
        created_at=created_at,
    )
async def list_image_models():
    """List the image generation models advertised by the API."""
    # One (id, display name, description) row per model, in response order.
    catalog = (
        ("sd-1.5", "Stable Diffusion 1.5", "经典SD模型"),
        ("sd-2.1", "Stable Diffusion 2.1", "改进版SD"),
        ("sdxl", "Stable Diffusion XL", "高质量1024px"),
        ("sdxl-turbo", "SDXL Turbo", "快速生成"),
        ("sd3", "Stable Diffusion 3", "最新SD3"),
        ("flux-schnell", "FLUX Schnell", "快速高质量"),
        ("flux-dev", "FLUX Dev", "开发版"),
        ("playground-v2.5", "Playground v2.5", "美学优化"),
    )
    return {
        "models": [
            {"id": model_id, "name": label, "description": summary}
            for model_id, label, summary in catalog
        ]
    }
async def list_image_styles():
    """List the style presets advertised by the API."""
    # One (id, display name) pair per preset, in response order.
    presets = (
        ("cinematic", "电影感"),
        ("anime", "动漫"),
        ("photorealistic", "写实"),
        ("oil_painting", "油画"),
        ("watercolor", "水彩"),
        ("3d_render", "3D渲染"),
        ("pixel_art", "像素艺术"),
        ("concept_art", "概念艺术"),
        ("fantasy", "奇幻"),
        ("sci_fi", "科幻"),
    )
    return {
        "styles": [
            {"id": style_id, "name": label}
            for style_id, label in presets
        ]
    }
| # ============================================ | |
| # 视频生成API | |
| # ============================================ | |
async def generate_video(request: VideoGenerationRequest, background_tasks: BackgroundTasks):
    """
    Create an AI video generation task.

    Supported models:
    - animatediff: AnimateDiff
    - svd: Stable Video Diffusion
    - cogvideox-2b: CogVideoX 2B
    """
    # Register the task first, then hand the work to the background runner.
    new_task_id = create_task("video_generation")
    background_tasks.add_task(process_video_generation, new_task_id, request)
    created_at = tasks_db[new_task_id]["created_at"]
    return TaskResponse(
        task_id=new_task_id,
        status=TaskStatus.PENDING,
        message="视频生成任务已创建",
        created_at=created_at,
    )
async def list_video_models():
    """List the video generation models advertised by the API."""
    # One (id, display name, description) row per model, in response order.
    catalog = (
        ("animatediff", "AnimateDiff", "动画生成"),
        ("animatediff-v3", "AnimateDiff v3", "改进版"),
        ("svd", "Stable Video Diffusion", "图生视频"),
        ("svd-xt", "SVD-XT", "更长视频"),
        ("cogvideox-2b", "CogVideoX 2B", "文生视频"),
        ("cogvideox-5b", "CogVideoX 5B", "高质量"),
    )
    return {
        "models": [
            {"id": model_id, "name": label, "description": summary}
            for model_id, label, summary in catalog
        ]
    }
| # ============================================ | |
| # 音频生成API | |
| # ============================================ | |
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """
    Create a text-to-speech task.

    Supported voices:
    - en_female/en_male: English
    - zh_female/zh_male: Chinese
    - ja_female/ja_male: Japanese
    """
    # Register the task first, then hand the work to the background runner.
    new_task_id = create_task("tts")
    background_tasks.add_task(process_tts, new_task_id, request)
    created_at = tasks_db[new_task_id]["created_at"]
    return TaskResponse(
        task_id=new_task_id,
        status=TaskStatus.PENDING,
        message="TTS任务已创建",
        created_at=created_at,
    )
async def generate_music(request: MusicGenerationRequest, background_tasks: BackgroundTasks):
    """
    Create an AI music generation task.

    Supported models:
    - musicgen-small/medium/large: Meta MusicGen
    - audioldm2-music: AudioLDM2
    """
    # Register the task first, then hand the work to the background runner.
    new_task_id = create_task("music_generation")
    background_tasks.add_task(process_music_generation, new_task_id, request)
    created_at = tasks_db[new_task_id]["created_at"]
    return TaskResponse(
        task_id=new_task_id,
        status=TaskStatus.PENDING,
        message="音乐生成任务已创建",
        created_at=created_at,
    )
async def speech_to_text(
    file: UploadFile = File(...),
    language: Optional[str] = None,
    model: str = "whisper-base",
):
    """
    Transcribe an uploaded audio file to text.

    Supported models:
    - whisper-tiny/base/small/medium/large

    The upload is written to a uniquely named temporary file under
    ``OUTPUT_DIR``, transcribed, and the temporary file is removed whether or
    not transcription succeeds.

    Returns:
        A dict with the transcribed ``"text"`` and the requested ``"language"``
        (``"auto"`` when none was given).

    Raises:
        HTTPException: status 500 carrying the error text for any failure.
    """
    # Uploads may arrive without a filename; treat that as "no extension"
    # instead of failing on Path(None). Only the suffix of the client-supplied
    # name is used, never the name itself.
    suffix = Path(file.filename or "").suffix
    file_path = OUTPUT_DIR / f"stt_{uuid.uuid4()}{suffix}"
    try:
        # Save the uploaded bytes so the transcriber can read them from disk.
        content = await file.read()
        with open(file_path, "wb") as f:
            f.write(content)
        # Imported lazily, as in the original.
        from ai_generation import SpeechToText

        stt = SpeechToText(model_name=model)
        stt.load_model()
        result = stt.transcribe(str(file_path), language=language)
        return {"text": result["text"], "language": language or "auto"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temporary upload, including on failure; it may not
        # exist if the write itself failed.
        try:
            file_path.unlink()
        except FileNotFoundError:
            pass
async def list_voices():
    """List the voice presets advertised by the API."""
    # One (id, display name, language code) row per voice, in response order.
    presets = (
        ("en_female", "英语女声", "en"),
        ("en_male", "英语男声", "en"),
        ("zh_female", "中文女声", "zh"),
        ("zh_male", "中文男声", "zh"),
        ("ja_female", "日语女声", "ja"),
        ("ja_male", "日语男声", "ja"),
        ("ko_female", "韩语女声", "ko"),
        ("ko_male", "韩语男声", "ko"),
    )
    return {
        "voices": [
            {"id": voice_id, "name": label, "language": lang_code}
            for voice_id, label, lang_code in presets
        ]
    }
| # ============================================ | |
| # 任务管理API | |
| # ============================================ | |
async def get_task(task_id: str):
    """Return the status of one task.

    The response always carries the id, type, status and timestamps. A
    download URL is added for completed tasks that have a result, and the
    error text is added when one was recorded.

    Raises:
        HTTPException: status 404 when the task id is unknown.
    """
    record = tasks_db.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    payload = {"task_id": record["id"]}
    for field in ("type", "status", "created_at", "updated_at"):
        payload[field] = record[field]

    finished_with_result = record["status"] == TaskStatus.COMPLETED and record["result"]
    if finished_with_result:
        payload["result_url"] = f"/api/v1/tasks/{task_id}/download"
    if record["error"]:
        payload["error"] = record["error"]
    return payload
async def download_task_result(task_id: str):
    """Return the result file of a completed task as a file response.

    Raises:
        HTTPException: status 404 when the task id is unknown or the result
            file is missing; status 400 when the task has not completed.
    """
    record = tasks_db.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    if record["status"] != TaskStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="任务未完成")

    stored_path = record["result"]
    # Short-circuit keeps Path() from being built when no result was recorded.
    if not (stored_path and Path(stored_path).exists()):
        raise HTTPException(status_code=404, detail="结果文件不存在")

    download_name = Path(stored_path).name
    return FileResponse(stored_path, filename=download_name)
async def list_tasks(
    status: Optional[TaskStatus] = None,
    limit: int = 20,
):
    """List stored tasks, newest first.

    Args:
        status: When given, only tasks with this status are returned.
        limit: Maximum number of tasks to return; negative values are
            treated as 0.

    Returns:
        A dict with the (possibly truncated) ``"tasks"`` list and ``"total"``,
        the number of matching tasks before truncation.
    """
    tasks = list(tasks_db.values())
    if status is not None:
        tasks = [t for t in tasks if t["status"] == status]
    # Newest first: created_at is an ISO-8601 string, so string order matches
    # chronological order.
    tasks.sort(key=lambda t: t["created_at"], reverse=True)
    # A negative limit would slice from the end of the list ("all but the
    # last N"), which is never what a caller means; clamp it to zero.
    page_size = max(limit, 0)
    return {"tasks": tasks[:page_size], "total": len(tasks)}
async def delete_task(task_id: str):
    """Delete a task record and, if present, its result file.

    Raises:
        HTTPException: status 404 when the task id is unknown.
    """
    record = tasks_db.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Remove the result file before dropping the record.
    stored_path = record["result"]
    if stored_path and Path(stored_path).exists():
        Path(stored_path).unlink()

    del tasks_db[task_id]
    return {"message": "任务已删除"}
| # ============================================ | |
| # 批量处理API | |
| # ============================================ | |
async def batch_generate_images(request: BatchRequest, background_tasks: BackgroundTasks):
    """
    Create one image generation task per prompt.

    Submit several prompts; the response carries a batch id and the ids of
    the individual tasks.

    Raises:
        HTTPException: status 422 when any prompt/size combination fails
            ``ImageGenerationRequest`` validation.
    """
    # Build and validate every per-prompt request before any task is created.
    # Otherwise a validation failure would leave already-registered tasks
    # stuck in PENDING and surface as an unhandled server error. Pydantic's
    # ValidationError is a ValueError subclass.
    try:
        image_requests = [
            ImageGenerationRequest(
                prompt=prompt,
                model=request.model,
                width=request.width,
                height=request.height,
            )
            for prompt in request.prompts
        ]
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    batch_id = str(uuid.uuid4())
    task_ids = []
    for img_request in image_requests:
        task_id = create_task("image_generation")
        task_ids.append(task_id)
        background_tasks.add_task(process_image_generation, task_id, img_request)

    return {
        "batch_id": batch_id,
        "task_ids": task_ids,
        "total": len(task_ids),
        "message": f"已创建 {len(task_ids)} 个图像生成任务",
    }
| # ============================================ | |
| # 启动服务 | |
| # ============================================ | |
if __name__ == "__main__":
    # Script entry point: start uvicorn against the "main:app" import string.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
    }
    uvicorn.run("main:app", **server_options)