#!/usr/bin/env python3 """FastAPI Server for Flow Agent — expose CLI functionality via HTTP/HTTPS. Allows n8n and other remote systems to trigger video/image generation and upload assets. """ import os import sys import uuid import time import shutil import base64 import logging import asyncio from typing import List, Optional from contextlib import asynccontextmanager # Add parent dir to sys.path so omniflash can be imported sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from omniflash import ( ExtensionBridge, generate_video, edit_video, poll_status, download_video, ASPECTS, DEFAULT_PROJECT, ) from omniflash.generators.i2v import upload_image, generate_video_i2v, generate_video_fl, generate_video_r2v from omniflash.generators.t2i import generate_image, download_image, IMAGE_ASPECTS from omniflash.upload import upload_video # Setup logging log = logging.getLogger("omniflash.api") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S") # Ensure required directories exist ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) OUTPUT_DIR = os.path.join(ROOT_DIR, "output") TEMP_DIR = os.path.join(OUTPUT_DIR, ".temp") def ensure_temp_dir(): os.makedirs(TEMP_DIR, exist_ok=True) def cleanup_temp_dir(): try: if os.path.exists(TEMP_DIR) and not os.listdir(TEMP_DIR): os.rmdir(TEMP_DIR) except Exception: pass # Global ExtensionBridge instance bridge: Optional[ExtensionBridge] = None @asynccontextmanager async def lifespan(app: FastAPI): global bridge log.info("🚀 Starting Flow Agent Extension Bridge...") bridge = ExtensionBridge() await bridge.start() # Run extension connection in background so the API server starts immediately asyncio.create_task(bridge.wait_for_extension(timeout=30)) yield log.info("🔌 Closing Flow Agent Extension Bridge...") if bridge: await bridge.close() cleanup_temp_dir() app = FastAPI( title="Flow Agent API", description="API Server to trigger Google Flow AI video and image generation", version="1.0.0", lifespan=lifespan ) # Enable CORS for convenience app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Helper function to check/reconnect the bridge async def get_active_bridge() -> ExtensionBridge: global bridge if not bridge: raise HTTPException(status_code=503, detail="Extension bridge is not initialized") # Try a quick health check is_healthy = await bridge.health_check() if not is_healthy: log.info("🔄 Bridge health check failed. Re-waiting for extension connection...") # Attempt to reconnect / grab flowKey connected = await bridge.wait_for_extension(timeout=10, max_retries=1) if not connected: raise HTTPException( status_code=503, detail="Google Flow extension is not connected or unauthorized. Make sure Google Flow tab is open in Chrome." ) return bridge # Helper to process image inputs (local path, media_id, or base64 data) async def resolve_image_input(active_bridge: ExtensionBridge, path_or_id_or_b64: str, project_id: str) -> str: if not path_or_id_or_b64: return "" # Case 1: Base64 data (e.g. data:image/png;base64,... or raw base64) if path_or_id_or_b64.startswith("data:") or len(path_or_id_or_b64) > 500: try: if "," in path_or_id_or_b64: base64_data = path_or_id_or_b64.split(",", 1)[1] else: base64_data = path_or_id_or_b64 img_bytes = base64.b64decode(base64_data) temp_filename = f"b64_{uuid.uuid4().hex}.png" ensure_temp_dir() temp_path = os.path.join(TEMP_DIR, temp_filename) with open(temp_path, "wb") as f: f.write(img_bytes) mid = await upload_image(active_bridge, temp_path, project_id) try: os.remove(temp_path) except OSError: pass cleanup_temp_dir() if not mid: raise HTTPException(status_code=400, detail="Failed to upload base64 image reference") return mid except Exception as e: raise HTTPException(status_code=400, detail=f"Failed parsing base64 image: {str(e)}") # Case 2: Local file path if os.path.exists(path_or_id_or_b64): mid = await upload_image(active_bridge, path_or_id_or_b64, project_id) if not mid: raise HTTPException(status_code=400, detail=f"Failed to upload local image path: {path_or_id_or_b64}") return mid # Case 3: Already a Media ID (UUID or similar format) return path_or_id_or_b64 # Request Models class VideoGenerationRequest(BaseModel): prompt: str = Field(..., description="Text prompt for video generation") aspect: str = Field("portrait", description="Aspect ratio: 'portrait' or 'landscape'") duration: int = Field(10, description="Duration in seconds: 4, 6, 8, or 10") count: int = Field(1, description="Number of variations (1-4)") project_id: str = Field(DEFAULT_PROJECT, description="Flow project ID") start: Optional[str] = Field(None, description="Start frame image (file path, media_id, or base64)") end: Optional[str] = Field(None, description="End frame image (use with start for FL mode)") ref: Optional[List[str]] = Field(None, description="Reference image(s) (file path, media_id, or base64)") edit: Optional[str] = Field(None, description="Flow video media_id for video editing (V2V)") no_clean: bool = Field(False, description="Skip watermark removal") class ImageGenerationRequest(BaseModel): prompt: str = Field(..., description="Text prompt for image generation") aspect: str = Field("portrait", description="Aspect ratio: 'portrait', 'landscape', 'square', '4x3', '3x4'") count: int = Field(1, description="Number of variations (1-4)") ref: Optional[List[str]] = Field(None, description="Reference image(s) (file path, media_id, or base64)") project_id: str = Field(DEFAULT_PROJECT, description="Flow project ID") class VideoEditRequest(BaseModel): prompt: str = Field(..., description="Restyle/edit text prompt") video_media_id: str = Field(..., description="Original video media_id") aspect: str = Field("portrait", description="Aspect ratio: 'portrait' or 'landscape'") fps: int = Field(24, description="FPS of source video") duration: int = Field(10, description="Duration of segment to edit") start_frame: int = Field(0, description="Start frame index") end_frame: Optional[int] = Field(None, description="End frame index") project_id: str = Field(DEFAULT_PROJECT, description="Flow project ID") download: bool = Field(False, description="Directly download binary video stream") # API Routes @app.get("/health") async def health(): """Check API server connection and Chrome extension authorization.""" global bridge if not bridge: return {"status": "starting", "extension_connected": False, "has_flow_key": False} is_healthy = await bridge.health_check() return { "status": "healthy" if is_healthy else "disconnected_or_unauthorized", "extension_connected": bridge._ws is not None, "has_flow_key": bridge._flow_key is not None } @app.post("/upload/image") async def api_upload_image( file: Optional[UploadFile] = File(None), path: Optional[str] = Form(None), project_id: str = Form(DEFAULT_PROJECT) ): """Upload an image to Google Flow. Accepts multipart file upload or local file path.""" active_bridge = await get_active_bridge() temp_path = None if file: temp_filename = f"upload_{uuid.uuid4().hex}_{file.filename}" ensure_temp_dir() temp_path = os.path.join(TEMP_DIR, temp_filename) with open(temp_path, "wb") as f: shutil.copyfileobj(file.file, f) upload_path = temp_path elif path: if not os.path.exists(path): raise HTTPException(status_code=404, detail=f"Local file not found: {path}") upload_path = path else: raise HTTPException(status_code=400, detail="Must provide 'file' (multipart) or 'path' (form parameter)") try: media_id = await upload_image(active_bridge, upload_path, project_id) if not media_id: raise HTTPException(status_code=500, detail="Flow image upload failed") return {"success": True, "media_id": media_id} finally: if temp_path and os.path.exists(temp_path): try: os.remove(temp_path) except OSError: pass cleanup_temp_dir() @app.post("/upload/video") async def api_upload_video( file: Optional[UploadFile] = File(None), path: Optional[str] = Form(None), project_id: str = Form(DEFAULT_PROJECT) ): """Upload a video to Google Flow. Accepts multipart file upload or local file path.""" active_bridge = await get_active_bridge() temp_path = None if file: temp_filename = f"upload_{uuid.uuid4().hex}_{file.filename}" ensure_temp_dir() temp_path = os.path.join(TEMP_DIR, temp_filename) with open(temp_path, "wb") as f: shutil.copyfileobj(file.file, f) upload_path = temp_path elif path: if not os.path.exists(path): raise HTTPException(status_code=404, detail=f"Local file not found: {path}") upload_path = path else: raise HTTPException(status_code=400, detail="Must provide 'file' (multipart) or 'path' (form parameter)") try: result = await upload_video(upload_path, project_id, active_bridge) media_id = result.get("mediaId") or result.get("name") or result.get("id") if not media_id and isinstance(result.get("media"), dict): media_id = result["media"].get("name") or result["media"].get("mediaId") if not media_id: raise HTTPException(status_code=500, detail=f"Flow video upload failed: {result}") return {"success": True, "media_id": media_id, "data": result} finally: if temp_path and os.path.exists(temp_path): try: os.remove(temp_path) except OSError: pass cleanup_temp_dir() @app.post("/generate/video") async def api_generate_video(req: VideoGenerationRequest, download: bool = Query(False)): """Generate or edit video via text prompt and optional references (T2V, I2V, FL, R2V, V2V).""" active_bridge = await get_active_bridge() aspect = ASPECTS.get(req.aspect, "VIDEO_ASPECT_RATIO_PORTRAIT") # 1. Resolve starting image (I2V / FL) start_id = None if req.start: start_id = await resolve_image_input(active_bridge, req.start, req.project_id) # 2. Resolve end image (FL) end_id = None if req.end: end_id = await resolve_image_input(active_bridge, req.end, req.project_id) # 3. Resolve reference images (R2V) ref_ids = [] if req.ref: for r in req.ref: mid = await resolve_image_input(active_bridge, r, req.project_id) if mid: ref_ids.append(mid) # 4. Trigger generation media_ids = None if start_id and end_id: media_ids = await generate_video_fl( active_bridge, req.prompt, aspect, req.project_id, start_image_id=start_id, end_image_id=end_id, duration=req.duration ) elif start_id: media_ids = await generate_video_i2v( active_bridge, req.prompt, aspect, req.project_id, image_media_id=start_id, duration=req.duration ) elif ref_ids: media_ids = await generate_video_r2v( active_bridge, req.prompt, aspect, req.project_id, ref_media_ids=ref_ids, duration=req.duration ) elif req.edit: media_ids = await edit_video( active_bridge, req.prompt, aspect, req.project_id, video_media_id=req.edit, duration=req.duration ) else: media_ids = await generate_video( active_bridge, req.prompt, aspect, req.project_id, duration=req.duration, count=req.count ) if not media_ids: raise HTTPException(status_code=500, detail="Failed to initiate video generation") outputs = [] timestamp = int(time.time()) # 5. Poll and Download for i, media_id in enumerate(media_ids): log.info(f"Polling video [{i+1}/{len(media_ids)}] ID: {media_id}") if not await poll_status(active_bridge, media_id, req.project_id): log.error(f"Polling failed for media ID: {media_id}") continue unique_id = uuid.uuid4().hex[:6] filename = f"omni_{timestamp}_{unique_id}_{i+1}.mp4" out_path = os.path.join(OUTPUT_DIR, filename) ensure_temp_dir() temp_path = os.path.join(TEMP_DIR, filename) if await download_video(active_bridge, media_id, temp_path): # Watermark removal disabled — save the raw downloaded video as-is. os.replace(temp_path, out_path) outputs.append({ "media_id": media_id, "filename": filename, "local_path": out_path, "download_url": f"/download/{filename}" }) cleanup_temp_dir() if not outputs: raise HTTPException(status_code=500, detail="Failed to download generated video(s)") # Return binary directly if requested and single file if download and len(outputs) == 1: return FileResponse( path=outputs[0]["local_path"], filename=outputs[0]["filename"], media_type="video/mp4" ) return { "success": True, "outputs": outputs } @app.post("/generate/image") async def api_generate_image(req: ImageGenerationRequest, download: bool = Query(False)): """Generate image using text prompt and optional reference images (T2I, I2I).""" active_bridge = await get_active_bridge() aspect = req.aspect # Resolve reference images if any ref_ids = [] if req.ref: for r in req.ref: mid = await resolve_image_input(active_bridge, r, req.project_id) if mid: ref_ids.append(mid) results = await generate_image( active_bridge, req.prompt, aspect, req.project_id, count=req.count, ref_media_ids=ref_ids or None ) if not results: raise HTTPException(status_code=500, detail="Failed to generate image") outputs = [] timestamp = int(time.time()) for i, r in enumerate(results): url = r.get("image_url") media_id = r.get("media_id") if not url: continue unique_id = uuid.uuid4().hex[:6] filename = f"img_{timestamp}_{unique_id}_{i+1}.png" out_path = os.path.join(OUTPUT_DIR, filename) download_success = await download_image(active_bridge, url, out_path) outputs.append({ "media_id": media_id, "filename": filename, "local_path": out_path if download_success else None, "download_url": f"/download/{filename}" if download_success else None, "remote_url": url, "downloaded": download_success }) # Return binary directly if requested, single image, and it was successfully downloaded if download and len(outputs) == 1 and outputs[0]["downloaded"]: return FileResponse( path=outputs[0]["local_path"], filename=outputs[0]["filename"], media_type="image/png" ) return { "success": True, "outputs": outputs } @app.post("/edit/video") async def api_edit_video(req: VideoEditRequest): """Submit V2V edit request.""" active_bridge = await get_active_bridge() aspect = ASPECTS.get(req.aspect, "VIDEO_ASPECT_RATIO_PORTRAIT") media_ids = await edit_video( active_bridge, req.prompt, aspect, req.project_id, video_media_id=req.video_media_id, fps=req.fps, duration=req.duration, start_frame=req.start_frame, end_frame=req.end_frame ) if not media_ids: raise HTTPException(status_code=500, detail="Failed to submit V2V edit request") outputs = [] timestamp = int(time.time()) for i, media_id in enumerate(media_ids): log.info(f"Polling edited video [{i+1}/{len(media_ids)}] ID: {media_id}") if not await poll_status(active_bridge, media_id, req.project_id): continue unique_id = uuid.uuid4().hex[:6] filename = f"edit_{timestamp}_{unique_id}_{i+1}.mp4" out_path = os.path.join(OUTPUT_DIR, filename) ensure_temp_dir() temp_path = os.path.join(TEMP_DIR, filename) if await download_video(active_bridge, media_id, temp_path): # Watermark removal disabled — save the raw downloaded video as-is. os.replace(temp_path, out_path) outputs.append({ "media_id": media_id, "filename": filename, "local_path": out_path, "download_url": f"/download/{filename}" }) cleanup_temp_dir() if not outputs: raise HTTPException(status_code=500, detail="Failed to download edited video(s)") if req.download and len(outputs) == 1: return FileResponse( path=outputs[0]["local_path"], filename=outputs[0]["filename"], media_type="video/mp4" ) return { "success": True, "outputs": outputs } @app.get("/download/{filename}") async def api_download_file(filename: str): """Download generated assets from output folder.""" file_path = os.path.join(OUTPUT_DIR, filename) if not os.path.exists(file_path): raise HTTPException(status_code=404, detail="Requested file not found") # Standardize content types media_type = "application/octet-stream" if filename.endswith(".mp4"): media_type = "video/mp4" elif filename.endswith(".png"): media_type = "image/png" elif filename.endswith(".jpg") or filename.endswith(".jpeg"): media_type = "image/jpeg" return FileResponse(path=file_path, filename=filename, media_type=media_type) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Flow Agent API Server") parser.add_argument("--host", default="127.0.0.1", help="Host address") parser.add_argument("--port", type=int, default=8000, help="Port to run on") parser.add_argument("--ssl", action="store_true", help="Enable self-signed SSL certificate") parser.add_argument("--ssl-certfile", help="SSL certificate file path") parser.add_argument("--ssl-keyfile", help="SSL private key file path") args = parser.parse_args() ssl_keyfile = args.ssl_keyfile ssl_certfile = args.ssl_certfile if args.ssl and not (ssl_keyfile and ssl_certfile): try: from cryptography import x509 from cryptography.x509.oid import NameOID from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization import datetime # Generate RSA key key = rsa.generate_private_key( public_exponent=65537, key_size=2048, ) # Create self-signed cert info subject = issuer = x509.Name([ x509.NameAttribute(NameOID.COMMON_NAME, u"localhost"), ]) cert = x509.CertificateBuilder().subject_name( subject ).issuer_name( issuer ).public_key( key.public_key() ).serial_number( x509.random_serial_number() ).not_valid_before( datetime.datetime.utcnow() ).not_valid_after( datetime.datetime.utcnow() + datetime.timedelta(days=365) ).add_extension( x509.SubjectAlternativeName([x509.DNSName(u"localhost")]), critical=False, ).sign(key, hashes.SHA256()) ssl_dir = os.path.join(OUTPUT_DIR, ".ssl") os.makedirs(ssl_dir, exist_ok=True) ssl_keyfile = os.path.join(ssl_dir, "key.pem") ssl_certfile = os.path.join(ssl_dir, "cert.pem") with open(ssl_keyfile, "wb") as f: f.write(key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), )) with open(ssl_certfile, "wb") as f: f.write(cert.public_bytes(serialization.Encoding.PEM)) log.info(f"🔒 Generated temporary self-signed SSL certificate in {ssl_dir}") except ImportError: log.warning("⚠️ cryptography package not found. Cannot auto-generate self-signed SSL cert.") log.warning(" Please install it: pip install cryptography") log.warning(" Falling back to standard HTTP.") args.ssl = False import uvicorn uvicorn.run( "cli.api:app", host=args.host, port=args.port, ssl_keyfile=ssl_keyfile if args.ssl or (args.ssl_keyfile and args.ssl_keyfile) else None, ssl_certfile=ssl_certfile if args.ssl or (args.ssl_keyfile and args.ssl_keyfile) else None, )