# imagemerger / api.py
# Author: ashishninehertz
# Commit: 33c85bd ("first commit")
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import os
import random
from typing import Optional, List
import uvicorn
from pydantic import BaseModel
import io
import base64
from datetime import datetime
from diffusers import AutoencoderKL
from transformers import AutoTokenizer
from OmniGen import OmniGen, OmniGenProcessor, OmniGenPipeline
# FastAPI application exposing the OmniGen pipeline over HTTP.
app = FastAPI(
    version="1.0.0",
    title="OmniGen API",
    description="REST API for OmniGen: Unified Image Generation",
)

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins if cookies or auth are ever introduced.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Select the compute device: CUDA first, then Apple-silicon MPS, then CPU.
# The /generate handler enables the KV cache when CUDA is available, so CUDA
# is a supported target and must be picked here too, or the model would run
# on the CPU even on a GPU machine.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load the OmniGen components; model_path is handed to every from_pretrained
# call (presumably a Hugging Face Hub repo id — a local directory also works
# for from_pretrained-style loaders).
model_path = "Shitao/OmniGen-v1"
print("Loading model components...")
model = OmniGen.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
processor = OmniGenProcessor(tokenizer)

# Assemble the inference pipeline shared by all /generate requests.
pipe = OmniGenPipeline(
    vae=vae,
    model=model,
    processor=processor,
    device=device,
)
class GenerationRequest(BaseModel):
    """JSON-shaped description of one generation request.

    Mirrors the form fields accepted by the /generate endpoint, with the same
    defaults. NOTE(review): no visible code references this model — the
    endpoint reads multipart Form fields instead; confirm whether it is used
    elsewhere before relying on it.
    """

    # Text prompt (required).
    prompt: str

    # Output resolution in pixels.
    height: Optional[int] = 1024
    width: Optional[int] = 1024

    # Classifier-free guidance strengths for the text and the input images.
    guidance_scale: Optional[float] = 2.5
    img_guidance_scale: Optional[float] = 1.6

    # Number of denoising steps.
    inference_steps: Optional[int] = 50

    # Explicit seed; ignored when randomize_seed is true.
    seed: Optional[int] = None

    # Pipeline memory/behaviour switches, forwarded as-is.
    separate_cfg_infer: Optional[bool] = True
    offload_model: Optional[bool] = False
    use_input_image_size_as_output: Optional[bool] = False
    max_input_image_size: Optional[int] = 1024

    # Request-level options.
    randomize_seed: Optional[bool] = True
    save_images: Optional[bool] = False
async def process_image(image: UploadFile) -> Optional[str]:
    """Decode an uploaded image and write it to its own temporary PNG file.

    Args:
        image: The uploaded file, or ``None`` when the form field was omitted.

    Returns:
        Path of the newly written PNG file, or ``None`` when no image (or an
        empty upload) was supplied. The caller is responsible for deleting
        the file.

    Raises:
        HTTPException: 400 if the upload cannot be read, decoded, or
        re-encoded as PNG.
    """
    import contextlib
    import tempfile

    if image is None:
        return None
    try:
        contents = await image.read()
        if not contents:
            # Browsers submit an empty part for an unused <input type="file">;
            # treat it like an omitted field rather than a decoding error.
            return None
        # Close the decoder's handle deterministically.
        with Image.open(io.BytesIO(contents)) as img:
            # mkstemp guarantees a distinct file per upload. A timestamp with
            # one-second resolution would give every image of a single request
            # the same path, so later uploads would overwrite earlier ones.
            fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
            os.close(fd)
            try:
                img.save(temp_path, format="PNG")
            except Exception:
                # Don't leave a half-written file behind.
                with contextlib.suppress(OSError):
                    os.remove(temp_path)
                raise
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=f"Error processing image: {str(e)}"
        ) from e
    return temp_path
def _encode_png_base64(img) -> str:
    """Serialize *img* as PNG and return it base64-encoded as a str."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def _remove_temp_files(paths: List[str]) -> None:
    """Best-effort deletion of temporary input-image files; never raises."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            # Cleanup must not mask the real response or error.
            print(f"Warning: could not remove temporary file {path}: {e}")


@app.post("/generate")
async def generate_image(
    prompt: str = Form(...),
    image1: Optional[UploadFile] = File(None),
    image2: Optional[UploadFile] = File(None),
    image3: Optional[UploadFile] = File(None),
    height: int = Form(1024),
    width: int = Form(1024),
    guidance_scale: float = Form(2.5),
    img_guidance_scale: float = Form(1.6),
    inference_steps: int = Form(50),
    seed: Optional[int] = Form(None),
    separate_cfg_infer: bool = Form(True),
    offload_model: bool = Form(False),
    use_input_image_size_as_output: bool = Form(False),
    max_input_image_size: int = Form(1024),
    randomize_seed: bool = Form(True),
    save_images: bool = Form(False)
):
    """Generate one image from a prompt and up to three optional input images.

    Generation parameters are forwarded to the OmniGen pipeline unchanged.
    When ``randomize_seed`` is true or no ``seed`` is given, a random seed is
    drawn. With ``save_images`` the result is also written to ``outputs/``.

    Returns:
        ``{"status": "success", "image": <base64 PNG>, "seed": <seed used>}``.

    Raises:
        HTTPException: 400 if an uploaded image cannot be processed; 500 for
        any other failure during generation.
    """
    input_images: List[str] = []
    try:
        for upload in (image1, image2, image3):
            if upload is None:
                continue
            img_path = await process_image(upload)
            if img_path:
                input_images.append(img_path)

        if randomize_seed or seed is None:
            seed = random.randint(0, 10000000)

        # Enable (and offload) the KV cache only when CUDA is available.
        use_kv_cache = torch.cuda.is_available()
        offload_kv_cache = use_kv_cache

        # NOTE(review): pipe() is a plain synchronous call inside this async
        # handler, so it blocks the event loop (other requests, including
        # /health, wait). Moving it to a worker thread would also require
        # serializing access to the shared model.
        output = pipe(
            prompt=prompt,
            input_images=input_images or None,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            img_guidance_scale=img_guidance_scale,
            num_inference_steps=inference_steps,
            separate_cfg_infer=separate_cfg_infer,
            use_kv_cache=use_kv_cache,
            offload_kv_cache=offload_kv_cache,
            offload_model=offload_model,
            use_input_image_size_as_output=use_input_image_size_as_output,
            seed=seed,
            max_input_image_size=max_input_image_size,
        )
        img = output[0]

        if save_images:
            os.makedirs('outputs', exist_ok=True)
            timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
            output_path = os.path.join('outputs', f'{timestamp}.png')
            img.save(output_path)

        img_str = _encode_png_base64(img)
    except HTTPException:
        # Already carries the correct status (e.g. 400 for a bad upload);
        # don't re-wrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Runs on success and failure alike, so temporary inputs never leak.
        _remove_temp_files(input_images)

    return {
        "status": "success",
        "image": img_str,
        "seed": seed,
    }
@app.get("/health")
async def health_check():
    """Liveness probe; also reports the torch device the pipeline was built on."""
    payload = {"status": "healthy"}
    payload["device"] = str(device)
    return payload
if __name__ == "__main__":
    # Serve the already-constructed ``app`` object. An import string with
    # reload=True would make uvicorn start a worker process that re-imports
    # this module, loading the model a second time on top of the copy this
    # process already loaded at import. It would also fail whenever this file
    # is not named api.py.
    uvicorn.run(app, host="0.0.0.0", port=8000)