Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """FastAPI Server for Flow Agent — expose CLI functionality via HTTP/HTTPS. | |
| Allows n8n and other remote systems to trigger video/image generation and upload assets. | |
| """ | |
| import os | |
| import sys | |
| import uuid | |
| import time | |
| import shutil | |
| import base64 | |
| import logging | |
| import asyncio | |
| from typing import List, Optional | |
| from contextlib import asynccontextmanager | |
| # Add parent dir to sys.path so omniflash can be imported | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from omniflash import ( | |
| ExtensionBridge, generate_video, edit_video, | |
| poll_status, download_video, ASPECTS, DEFAULT_PROJECT, | |
| ) | |
| from omniflash.generators.i2v import upload_image, generate_video_i2v, generate_video_fl, generate_video_r2v | |
| from omniflash.generators.t2i import generate_image, download_image, IMAGE_ASPECTS | |
| from omniflash.upload import upload_video | |
| # Setup logging | |
# Process-wide logging configuration, then the logger used by this module.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("omniflash.api")

# Filesystem layout, derived from this file's location:
#   ROOT_DIR   - the directory two levels above this file
#   OUTPUT_DIR - ROOT_DIR/output, where finished assets are written
#   TEMP_DIR   - OUTPUT_DIR/.temp, scratch space for uploads and downloads
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUTPUT_DIR = os.path.join(ROOT_DIR, "output")
TEMP_DIR = os.path.join(OUTPUT_DIR, ".temp")
def ensure_temp_dir():
    """Create TEMP_DIR (and any missing parent directories) if it is absent."""
    os.makedirs(name=TEMP_DIR, exist_ok=True)
def cleanup_temp_dir():
    """Remove TEMP_DIR when it exists and is empty; otherwise do nothing.

    Best-effort: any failure (directory missing, not empty, not removable)
    is silently ignored.
    """
    try:
        # os.rmdir refuses to delete a missing or non-empty directory, which
        # gives the same outcome as checking for existence and emptiness first.
        os.rmdir(TEMP_DIR)
    except Exception:
        return
# Process-wide ExtensionBridge instance. It stays None until the application
# lifespan hook assigns it; request handlers read it from here.
bridge: Optional[ExtensionBridge] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start the extension bridge, then tear it down.

    Startup creates the global ``bridge``, starts it, and begins waiting for
    the browser extension in a background task so the HTTP server can accept
    requests immediately. Shutdown cancels that wait if it is still pending,
    closes the bridge, and removes the temp directory if it is empty.

    Args:
        app: The FastAPI application this lifespan is attached to.
    """
    global bridge
    log.info("🚀 Starting Flow Agent Extension Bridge...")
    bridge = ExtensionBridge()
    await bridge.start()
    # Run extension connection in background so the API server starts immediately.
    # The task is bound to a local name because the event loop only keeps weak
    # references to tasks; an unreferenced task may be garbage-collected mid-run.
    wait_task = asyncio.create_task(bridge.wait_for_extension(timeout=30))
    try:
        yield
    finally:
        # Shutdown must run even if the application exits with an error.
        log.info("🔌 Closing Flow Agent Extension Bridge...")
        if not wait_task.done():
            wait_task.cancel()
        if bridge:
            await bridge.close()
        cleanup_temp_dir()
# The ASGI application; bridge startup/shutdown is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Flow Agent API",
    description="API Server to trigger Google Flow AI video and image generation",
)

# Enable CORS for convenience: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this server is only reachable from trusted networks.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Helper function to check/reconnect the bridge
async def get_active_bridge() -> ExtensionBridge:
    """Return the global bridge, re-waiting for the extension if it is unhealthy.

    Raises:
        HTTPException: 503 when the bridge has not been created yet, or when
            the health check fails and the extension does not reconnect
            within the retry window.
    """
    if not bridge:
        raise HTTPException(status_code=503, detail="Extension bridge is not initialized")

    # Fast path: a passing health check means the bridge is usable as-is.
    if await bridge.health_check():
        return bridge

    log.info("🔄 Bridge health check failed. Re-waiting for extension connection...")
    # Attempt to reconnect / grab flowKey
    reconnected = await bridge.wait_for_extension(timeout=10, max_retries=1)
    if reconnected:
        return bridge

    raise HTTPException(
        status_code=503,
        detail="Google Flow extension is not connected or unauthorized. Make sure Google Flow tab is open in Chrome."
    )
# Helper to process image inputs (local path, media_id, or base64 data)
async def resolve_image_input(active_bridge: ExtensionBridge, path_or_id_or_b64: str, project_id: str) -> str:
    """Turn an image reference into a Flow media ID.

    The input is interpreted in this order:
      1. Base64 data: a ``data:`` URI, or any string longer than 500
         characters. It is decoded, written to a temp file and uploaded.
      2. An existing local file path, which is uploaded.
      3. Anything else is returned unchanged and treated as a media ID.

    Args:
        active_bridge: Connected bridge used for the upload.
        path_or_id_or_b64: Base64 payload, local path, or media ID.
        project_id: Flow project to upload into.

    Returns:
        The media ID, or ``""`` when the input is empty.

    Raises:
        HTTPException: 400 when decoding or uploading fails.
    """
    if not path_or_id_or_b64:
        return ""

    # Case 1: Base64 data (e.g. data:image/png;base64,... or raw base64)
    if path_or_id_or_b64.startswith("data:") or len(path_or_id_or_b64) > 500:
        temp_path = None
        try:
            # Drop a "data:...;base64," style prefix when one is present.
            _, comma, payload = path_or_id_or_b64.partition(",")
            base64_data = payload if comma else path_or_id_or_b64
            img_bytes = base64.b64decode(base64_data)

            ensure_temp_dir()
            temp_path = os.path.join(TEMP_DIR, f"b64_{uuid.uuid4().hex}.png")
            with open(temp_path, "wb") as f:
                f.write(img_bytes)

            mid = await upload_image(active_bridge, temp_path, project_id)
            if not mid:
                raise HTTPException(status_code=400, detail="Failed to upload base64 image reference")
            return mid
        except HTTPException:
            # Already a well-formed API error; do not re-wrap it below.
            raise
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Failed parsing base64 image: {str(e)}") from e
        finally:
            # Always remove the scratch file, including when the upload raised.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
            cleanup_temp_dir()

    # Case 2: Local file path
    if os.path.exists(path_or_id_or_b64):
        mid = await upload_image(active_bridge, path_or_id_or_b64, project_id)
        if not mid:
            raise HTTPException(status_code=400, detail=f"Failed to upload local image path: {path_or_id_or_b64}")
        return mid

    # Case 3: Already a Media ID (UUID or similar format)
    return path_or_id_or_b64
# Request Models
class VideoGenerationRequest(BaseModel):
    """JSON body accepted by the video generation handler."""

    prompt: str = Field(default=..., description="Text prompt for video generation")
    aspect: str = Field(default="portrait", description="Aspect ratio: 'portrait' or 'landscape'")
    duration: int = Field(default=10, description="Duration in seconds: 4, 6, 8, or 10")
    count: int = Field(default=1, description="Number of variations (1-4)")
    project_id: str = Field(default=DEFAULT_PROJECT, description="Flow project ID")
    start: Optional[str] = Field(default=None, description="Start frame image (file path, media_id, or base64)")
    end: Optional[str] = Field(default=None, description="End frame image (use with start for FL mode)")
    ref: Optional[List[str]] = Field(default=None, description="Reference image(s) (file path, media_id, or base64)")
    edit: Optional[str] = Field(default=None, description="Flow video media_id for video editing (V2V)")
    no_clean: bool = Field(default=False, description="Skip watermark removal")
class ImageGenerationRequest(BaseModel):
    """JSON body accepted by the image generation handler."""

    prompt: str = Field(default=..., description="Text prompt for image generation")
    aspect: str = Field(default="portrait", description="Aspect ratio: 'portrait', 'landscape', 'square', '4x3', '3x4'")
    count: int = Field(default=1, description="Number of variations (1-4)")
    ref: Optional[List[str]] = Field(default=None, description="Reference image(s) (file path, media_id, or base64)")
    project_id: str = Field(default=DEFAULT_PROJECT, description="Flow project ID")
class VideoEditRequest(BaseModel):
    """JSON body accepted by the video edit (V2V) handler."""

    prompt: str = Field(default=..., description="Restyle/edit text prompt")
    video_media_id: str = Field(default=..., description="Original video media_id")
    aspect: str = Field(default="portrait", description="Aspect ratio: 'portrait' or 'landscape'")
    fps: int = Field(default=24, description="FPS of source video")
    duration: int = Field(default=10, description="Duration of segment to edit")
    start_frame: int = Field(default=0, description="Start frame index")
    end_frame: Optional[int] = Field(default=None, description="End frame index")
    project_id: str = Field(default=DEFAULT_PROJECT, description="Flow project ID")
    download: bool = Field(default=False, description="Directly download binary video stream")
# API Routes
async def health():
    """Report whether the bridge exists, passes its health check, and is authorized."""
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`.
    if not bridge:
        return {"status": "starting", "extension_connected": False, "has_flow_key": False}

    status = "healthy" if await bridge.health_check() else "disconnected_or_unauthorized"
    return {
        "status": status,
        "extension_connected": bridge._ws is not None,
        "has_flow_key": bridge._flow_key is not None
    }
async def api_upload_image(
    file: Optional[UploadFile] = File(None),
    path: Optional[str] = Form(None),
    project_id: str = Form(DEFAULT_PROJECT)
):
    """Upload an image to Google Flow. Accepts multipart file upload or local file path.

    Args:
        file: Multipart upload; takes precedence over ``path`` when both are sent.
        path: Path of a file already present on the server's filesystem.
        project_id: Flow project to upload into.

    Returns:
        ``{"success": True, "media_id": ...}`` on success.

    Raises:
        HTTPException: 400 when neither input is given, 404 when ``path``
            does not exist, 500 when the upload yields no media ID.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`.
    active_bridge = await get_active_bridge()
    temp_path = None
    try:
        if file:
            # The client controls file.filename: keep only its final path
            # component so it cannot point outside TEMP_DIR, and tolerate a
            # missing name.
            client_name = (file.filename or "").replace("\\", "/")
            safe_name = os.path.basename(client_name) or "upload"
            ensure_temp_dir()
            temp_path = os.path.join(TEMP_DIR, f"upload_{uuid.uuid4().hex}_{safe_name}")
            with open(temp_path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            upload_path = temp_path
        elif path:
            if not os.path.exists(path):
                raise HTTPException(status_code=404, detail=f"Local file not found: {path}")
            upload_path = path
        else:
            raise HTTPException(status_code=400, detail="Must provide 'file' (multipart) or 'path' (form parameter)")

        media_id = await upload_image(active_bridge, upload_path, project_id)
        if not media_id:
            raise HTTPException(status_code=500, detail="Flow image upload failed")
        return {"success": True, "media_id": media_id}
    finally:
        # Covers a failed temp write as well as a failed upload.
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
        cleanup_temp_dir()
async def api_upload_video(
    file: Optional[UploadFile] = File(None),
    path: Optional[str] = Form(None),
    project_id: str = Form(DEFAULT_PROJECT)
):
    """Upload a video to Google Flow. Accepts multipart file upload or local file path.

    Args:
        file: Multipart upload; takes precedence over ``path`` when both are sent.
        path: Path of a file already present on the server's filesystem.
        project_id: Flow project to upload into.

    Returns:
        ``{"success": True, "media_id": ..., "data": <raw upload result>}``.

    Raises:
        HTTPException: 400 when neither input is given, 404 when ``path``
            does not exist, 500 when no media ID can be found in the result.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`.
    active_bridge = await get_active_bridge()
    temp_path = None
    try:
        if file:
            # The client controls file.filename: keep only its final path
            # component so it cannot point outside TEMP_DIR, and tolerate a
            # missing name.
            client_name = (file.filename or "").replace("\\", "/")
            safe_name = os.path.basename(client_name) or "upload"
            ensure_temp_dir()
            temp_path = os.path.join(TEMP_DIR, f"upload_{uuid.uuid4().hex}_{safe_name}")
            with open(temp_path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            upload_path = temp_path
        elif path:
            if not os.path.exists(path):
                raise HTTPException(status_code=404, detail=f"Local file not found: {path}")
            upload_path = path
        else:
            raise HTTPException(status_code=400, detail="Must provide 'file' (multipart) or 'path' (form parameter)")

        result = await upload_video(upload_path, project_id, active_bridge)
        # A None or otherwise non-dict result is reported as a failed upload
        # instead of crashing on `.get`.
        if not isinstance(result, dict):
            raise HTTPException(status_code=500, detail=f"Flow video upload failed: {result}")

        # The ID is looked up under several keys, then inside a nested "media" dict.
        media_id = result.get("mediaId") or result.get("name") or result.get("id")
        if not media_id and isinstance(result.get("media"), dict):
            media_id = result["media"].get("name") or result["media"].get("mediaId")
        if not media_id:
            raise HTTPException(status_code=500, detail=f"Flow video upload failed: {result}")
        return {"success": True, "media_id": media_id, "data": result}
    finally:
        # Covers a failed temp write as well as a failed upload.
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
        cleanup_temp_dir()
async def api_generate_video(req: VideoGenerationRequest, download: bool = Query(False)):
    """Generate or edit video via text prompt and optional references (T2V, I2V, FL, R2V, V2V).

    The generator is chosen from the request, first match wins:
    start+end -> FL, start only -> I2V, ref -> R2V, edit -> V2V, else T2V.
    Each resulting media ID is polled and downloaded into OUTPUT_DIR.

    Returns:
        A FileResponse when ``download`` is set and exactly one video was
        saved; otherwise ``{"success": True, "outputs": [...]}``.

    Raises:
        HTTPException: 500 when generation cannot be started or nothing
            could be downloaded.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`.
    active_bridge = await get_active_bridge()
    aspect = ASPECTS.get(req.aspect, "VIDEO_ASPECT_RATIO_PORTRAIT")

    # 1. Resolve starting image (I2V / FL)
    start_id = await resolve_image_input(active_bridge, req.start, req.project_id) if req.start else None

    # 2. Resolve end image (FL)
    end_id = await resolve_image_input(active_bridge, req.end, req.project_id) if req.end else None

    # 3. Resolve reference images (R2V); inputs that resolve to nothing are dropped
    ref_ids = []
    for ref_input in req.ref or []:
        resolved = await resolve_image_input(active_bridge, ref_input, req.project_id)
        if resolved:
            ref_ids.append(resolved)

    # 4. Trigger generation; every generator shares the same leading arguments
    common = (active_bridge, req.prompt, aspect, req.project_id)
    if start_id and end_id:
        media_ids = await generate_video_fl(
            *common, start_image_id=start_id, end_image_id=end_id, duration=req.duration
        )
    elif start_id:
        media_ids = await generate_video_i2v(*common, image_media_id=start_id, duration=req.duration)
    elif ref_ids:
        media_ids = await generate_video_r2v(*common, ref_media_ids=ref_ids, duration=req.duration)
    elif req.edit:
        media_ids = await edit_video(*common, video_media_id=req.edit, duration=req.duration)
    else:
        media_ids = await generate_video(*common, duration=req.duration, count=req.count)

    if not media_ids:
        raise HTTPException(status_code=500, detail="Failed to initiate video generation")

    # 5. Poll and Download
    outputs = []
    timestamp = int(time.time())
    total = len(media_ids)
    for position, media_id in enumerate(media_ids, start=1):
        log.info(f"Polling video [{position}/{total}] ID: {media_id}")
        if not await poll_status(active_bridge, media_id, req.project_id):
            log.error(f"Polling failed for media ID: {media_id}")
            continue

        filename = f"omni_{timestamp}_{uuid.uuid4().hex[:6]}_{position}.mp4"
        out_path = os.path.join(OUTPUT_DIR, filename)
        ensure_temp_dir()
        temp_path = os.path.join(TEMP_DIR, filename)
        if not await download_video(active_bridge, media_id, temp_path):
            continue

        # Watermark removal disabled — save the raw downloaded video as-is.
        os.replace(temp_path, out_path)
        outputs.append({
            "media_id": media_id,
            "filename": filename,
            "local_path": out_path,
            "download_url": f"/download/{filename}"
        })

    cleanup_temp_dir()
    if not outputs:
        raise HTTPException(status_code=500, detail="Failed to download generated video(s)")

    # Return binary directly if requested and single file
    if download and len(outputs) == 1:
        only = outputs[0]
        return FileResponse(path=only["local_path"], filename=only["filename"], media_type="video/mp4")

    return {"success": True, "outputs": outputs}
async def api_generate_image(req: ImageGenerationRequest, download: bool = Query(False)):
    """Generate image using text prompt and optional reference images (T2I, I2I).

    Each result that carries an ``image_url`` is downloaded into OUTPUT_DIR.
    A failed download is still listed, with ``downloaded`` false and the
    remote URL kept, so the caller can fetch it directly.

    Returns:
        A FileResponse when ``download`` is set and exactly one image was
        saved; otherwise ``{"success": True, "outputs": [...]}``.

    Raises:
        HTTPException: 500 when generation returns no results.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`.
    active_bridge = await get_active_bridge()
    aspect = req.aspect

    # Resolve reference images if any; inputs that resolve to nothing are dropped
    ref_ids = []
    for ref_input in req.ref or []:
        mid = await resolve_image_input(active_bridge, ref_input, req.project_id)
        if mid:
            ref_ids.append(mid)

    results = await generate_image(
        active_bridge, req.prompt, aspect, req.project_id,
        count=req.count, ref_media_ids=ref_ids or None
    )
    if not results:
        raise HTTPException(status_code=500, detail="Failed to generate image")

    # Images are written straight into OUTPUT_DIR, and nothing else on this
    # path creates it, so make sure it exists before downloading.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    outputs = []
    timestamp = int(time.time())
    for i, r in enumerate(results):
        url = r.get("image_url")
        media_id = r.get("media_id")
        if not url:
            continue

        unique_id = uuid.uuid4().hex[:6]
        filename = f"img_{timestamp}_{unique_id}_{i+1}.png"
        out_path = os.path.join(OUTPUT_DIR, filename)
        download_success = await download_image(active_bridge, url, out_path)
        outputs.append({
            "media_id": media_id,
            "filename": filename,
            "local_path": out_path if download_success else None,
            "download_url": f"/download/{filename}" if download_success else None,
            "remote_url": url,
            "downloaded": download_success
        })

    # Return binary directly if requested, single image, and it was successfully downloaded
    if download and len(outputs) == 1 and outputs[0]["downloaded"]:
        return FileResponse(
            path=outputs[0]["local_path"],
            filename=outputs[0]["filename"],
            media_type="image/png"
        )

    return {
        "success": True,
        "outputs": outputs
    }
async def api_edit_video(req: VideoEditRequest):
    """Submit V2V edit request.

    Submits the edit, then polls and downloads every resulting media ID
    into OUTPUT_DIR.

    Returns:
        A FileResponse when ``req.download`` is set and exactly one video was
        saved; otherwise ``{"success": True, "outputs": [...]}``.

    Raises:
        HTTPException: 500 when the edit cannot be submitted or nothing
            could be downloaded.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`.
    active_bridge = await get_active_bridge()
    aspect = ASPECTS.get(req.aspect, "VIDEO_ASPECT_RATIO_PORTRAIT")

    media_ids = await edit_video(
        active_bridge,
        req.prompt,
        aspect,
        req.project_id,
        video_media_id=req.video_media_id,
        fps=req.fps,
        duration=req.duration,
        start_frame=req.start_frame,
        end_frame=req.end_frame,
    )
    if not media_ids:
        raise HTTPException(status_code=500, detail="Failed to submit V2V edit request")

    outputs = []
    timestamp = int(time.time())
    total = len(media_ids)
    for position, media_id in enumerate(media_ids, start=1):
        log.info(f"Polling edited video [{position}/{total}] ID: {media_id}")
        if not await poll_status(active_bridge, media_id, req.project_id):
            continue

        filename = f"edit_{timestamp}_{uuid.uuid4().hex[:6]}_{position}.mp4"
        out_path = os.path.join(OUTPUT_DIR, filename)
        ensure_temp_dir()
        temp_path = os.path.join(TEMP_DIR, filename)
        if not await download_video(active_bridge, media_id, temp_path):
            continue

        # Watermark removal disabled — save the raw downloaded video as-is.
        os.replace(temp_path, out_path)
        outputs.append({
            "media_id": media_id,
            "filename": filename,
            "local_path": out_path,
            "download_url": f"/download/{filename}"
        })

    cleanup_temp_dir()
    if not outputs:
        raise HTTPException(status_code=500, detail="Failed to download edited video(s)")

    if req.download and len(outputs) == 1:
        only = outputs[0]
        return FileResponse(path=only["local_path"], filename=only["filename"], media_type="video/mp4")

    return {"success": True, "outputs": outputs}
async def api_download_file(filename: str):
    """Download generated assets from output folder.

    Args:
        filename: Bare name of a file directly inside OUTPUT_DIR.

    Returns:
        A FileResponse whose media type is chosen from the file extension.

    Raises:
        HTTPException: 404 when the name is not a plain file name or no such
            regular file exists in OUTPUT_DIR.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # the generation handlers emit "/download/{filename}" URLs — confirm this
    # handler is registered on `app` under that path.

    # Only a bare file name is accepted: anything with a directory component,
    # a drive prefix, or "."/".." could resolve outside OUTPUT_DIR or reach
    # its hidden subfolders. Report it as not found rather than revealing why.
    if not filename or filename in (".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="Requested file not found")

    file_path = os.path.join(OUTPUT_DIR, filename)
    # isfile (not exists) so that a directory is never handed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Requested file not found")

    # Standardize content types
    if filename.endswith(".mp4"):
        media_type = "video/mp4"
    elif filename.endswith(".png"):
        media_type = "image/png"
    elif filename.endswith((".jpg", ".jpeg")):
        media_type = "image/jpeg"
    else:
        media_type = "application/octet-stream"

    return FileResponse(path=file_path, filename=filename, media_type=media_type)
def _generate_self_signed_cert(ssl_dir: str) -> tuple:
    """Create a one-year self-signed certificate for ``localhost`` in ``ssl_dir``.

    Returns:
        ``(key_path, cert_path)`` of the PEM files that were written.

    Raises:
        ImportError: when the ``cryptography`` package is not installed.
    """
    import datetime

    from cryptography import x509
    from cryptography.x509.oid import NameOID
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Generate RSA key
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

    # Self-signed: subject and issuer are the same name.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COMMON_NAME, u"localhost"),
    ])
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))
        .add_extension(
            x509.SubjectAlternativeName([x509.DNSName(u"localhost")]),
            critical=False,
        )
        .sign(key, hashes.SHA256())
    )

    os.makedirs(ssl_dir, exist_ok=True)
    key_path = os.path.join(ssl_dir, "key.pem")
    cert_path = os.path.join(ssl_dir, "cert.pem")
    with open(key_path, "wb") as f:
        f.write(key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        ))
    with open(cert_path, "wb") as f:
        f.write(cert.public_bytes(serialization.Encoding.PEM))
    return key_path, cert_path


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Flow Agent API Server")
    parser.add_argument("--host", default="127.0.0.1", help="Host address")
    parser.add_argument("--port", type=int, default=8000, help="Port to run on")
    parser.add_argument("--ssl", action="store_true", help="Enable self-signed SSL certificate")
    parser.add_argument("--ssl-certfile", help="SSL certificate file path")
    parser.add_argument("--ssl-keyfile", help="SSL private key file path")
    args = parser.parse_args()

    ssl_keyfile = args.ssl_keyfile
    ssl_certfile = args.ssl_certfile

    # TLS needs BOTH a key and a certificate. Use the supplied pair when it is
    # complete; otherwise, with --ssl, generate a self-signed pair.
    use_ssl = bool(ssl_keyfile and ssl_certfile)
    if not use_ssl and args.ssl:
        try:
            ssl_dir = os.path.join(OUTPUT_DIR, ".ssl")
            ssl_keyfile, ssl_certfile = _generate_self_signed_cert(ssl_dir)
            log.info(f"🔒 Generated temporary self-signed SSL certificate in {ssl_dir}")
            use_ssl = True
        except ImportError:
            log.warning("⚠️ cryptography package not found. Cannot auto-generate self-signed SSL cert.")
            log.warning(" Please install it: pip install cryptography")
            log.warning(" Falling back to standard HTTP.")
            args.ssl = False
    elif not use_ssl and (ssl_keyfile or ssl_certfile):
        # Only one half of the pair was supplied and --ssl was not requested.
        log.warning("Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Falling back to standard HTTP.")

    import uvicorn

    uvicorn.run(
        "cli.api:app",
        host=args.host,
        port=args.port,
        # Never pass a lone key or certificate through to uvicorn.
        ssl_keyfile=ssl_keyfile if use_ssl else None,
        ssl_certfile=ssl_certfile if use_ssl else None,
    )