"""REST API exposing the OmniGen image-generation pipeline through FastAPI.

Importing this module loads the model, tokenizer and VAE, so start-up is slow
and needs the model weights to be reachable.
"""

import base64
import io
import os
import random
import tempfile
from datetime import datetime
from typing import Optional, List

import torch
import uvicorn
from diffusers import AutoencoderKL
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer

from OmniGen import OmniGen, OmniGenProcessor, OmniGenPipeline

# Initialize FastAPI app
app = FastAPI(
    title="OmniGen API",
    description="REST API for OmniGen: Unified Image Generation",
    version="1.0.0",
)

# Add CORS middleware
# NOTE(review): every origin, method and header is allowed together with
# credentials; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Check for MPS availability
# NOTE(review): CUDA is never selected here, yet generate_image() switches the
# KV cache on based on torch.cuda.is_available() -- confirm this is intended.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize components
model_path = "Shitao/OmniGen-v1"
print("Loading model components...")
model = OmniGen.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
processor = OmniGenProcessor(tokenizer)

# Create pipeline
pipe = OmniGenPipeline(
    vae=vae,
    model=model,
    processor=processor,
    device=device,
)


class GenerationRequest(BaseModel):
    """Generation parameters expressed as a pydantic model.

    The field names and defaults mirror the form fields of the ``/generate``
    route; that route declares its parameters as form fields rather than
    taking this model.
    """

    prompt: str
    height: Optional[int] = 1024
    width: Optional[int] = 1024
    guidance_scale: Optional[float] = 2.5
    img_guidance_scale: Optional[float] = 1.6
    inference_steps: Optional[int] = 50
    seed: Optional[int] = None
    separate_cfg_infer: Optional[bool] = True
    offload_model: Optional[bool] = False
    use_input_image_size_as_output: Optional[bool] = False
    max_input_image_size: Optional[int] = 1024
    randomize_seed: Optional[bool] = True
    save_images: Optional[bool] = False


def _remove_temp_files(paths):
    """Delete each path in *paths*, ignoring files that are already gone."""
    for path in paths:
        if not path:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            # Cleanup is best effort: it runs from ``finally`` blocks and must
            # not replace the error (or the response) of the request itself.
            print(f"Warning: could not remove temporary file {path}: {e}")


async def process_image(image: UploadFile) -> Optional[str]:
    """Store an uploaded image as a temporary PNG file and return its path.

    Args:
        image: The uploaded file, or ``None``.

    Returns:
        Path of the newly written temporary PNG, or ``None`` when *image* is
        ``None``. The caller is responsible for deleting the file.

    Raises:
        HTTPException: With status 400 when the upload cannot be read,
            decoded or saved.
    """
    if image is None:
        return None

    temp_path = None
    try:
        contents = await image.read()
        # mkstemp() guarantees a unique name. A timestamp with one-second
        # resolution made several uploads of the same request (or of
        # concurrent requests) overwrite each other.
        fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
        os.close(fd)
        with Image.open(io.BytesIO(contents)) as img:
            img.save(temp_path)
        return temp_path
    except Exception as e:
        # Do not leave a half-written file behind when decoding/saving fails.
        _remove_temp_files([temp_path])
        raise HTTPException(
            status_code=400, detail=f"Error processing image: {str(e)}"
        ) from e


@app.post("/generate")
async def generate_image(
    prompt: str = Form(...),
    image1: Optional[UploadFile] = File(None),
    image2: Optional[UploadFile] = File(None),
    image3: Optional[UploadFile] = File(None),
    height: int = Form(1024),
    width: int = Form(1024),
    guidance_scale: float = Form(2.5),
    img_guidance_scale: float = Form(1.6),
    inference_steps: int = Form(50),
    seed: Optional[int] = Form(None),
    separate_cfg_infer: bool = Form(True),
    offload_model: bool = Form(False),
    use_input_image_size_as_output: bool = Form(False),
    max_input_image_size: int = Form(1024),
    randomize_seed: bool = Form(True),
    save_images: bool = Form(False),
):
    """Run the pipeline on the prompt and optional images; return a PNG.

    Up to three uploaded images are written to temporary files and handed to
    the pipeline as paths; the files are removed again before the request
    finishes, whether it succeeds or fails.

    Returns:
        A dict with ``status`` ("success"), ``image`` (the first pipeline
        output as a base64-encoded PNG) and ``seed`` (the seed actually used).

    Raises:
        HTTPException: 400 when an uploaded image cannot be processed, 500
            for any other failure.
    """
    temp_paths: List[str] = []
    try:
        # Process input images
        for upload in (image1, image2, image3):
            if upload is not None:
                img_path = await process_image(upload)
                if img_path:
                    temp_paths.append(img_path)

        # The pipeline is given None rather than an empty list.
        input_images = temp_paths or None

        if randomize_seed or seed is None:
            seed = random.randint(0, 10000000)

        # Enable KV cache only for CUDA
        use_kv_cache = offload_kv_cache = torch.cuda.is_available()

        # Generate image
        # NOTE(review): this call is synchronous inside an async handler, so
        # the event loop is blocked until generation finishes.
        output = pipe(
            prompt=prompt,
            input_images=input_images,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            img_guidance_scale=img_guidance_scale,
            num_inference_steps=inference_steps,
            separate_cfg_infer=separate_cfg_infer,
            use_kv_cache=use_kv_cache,
            offload_kv_cache=offload_kv_cache,
            offload_model=offload_model,
            use_input_image_size_as_output=use_input_image_size_as_output,
            seed=seed,
            max_input_image_size=max_input_image_size,
        )

        img = output[0]

        # Save image if requested
        if save_images:
            os.makedirs('outputs', exist_ok=True)
            timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
            output_path = os.path.join('outputs', f'{timestamp}.png')
            img.save(output_path)

        # Convert image to base64
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {
            "status": "success",
            "image": img_str,
            "seed": seed,
        }
    except HTTPException:
        # Keep the status code chosen further down (e.g. 400 for a bad
        # upload) instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary files on success and on failure alike.
        _remove_temp_files(temp_paths)


@app.get("/health")
async def health_check():
    """Report that the service is up and which torch device it uses."""
    return {"status": "healthy", "device": str(device)}


if __name__ == "__main__":
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)