Spaces:
Runtime error
Runtime error
File size: 5,541 Bytes
33c85bd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 | from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import os
import random
from typing import Optional, List
import uvicorn
from pydantic import BaseModel
import io
import base64
from datetime import datetime
from diffusers import AutoencoderKL
from transformers import AutoTokenizer
from OmniGen import OmniGen, OmniGenProcessor, OmniGenPipeline
# Application object shared by every route below. The metadata is kept in one
# mapping so the title, description and version can be read at a glance.
_APP_METADATA = {
    "title": "OmniGen API",
    "description": "REST API for OmniGen: Unified Image Generation",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
# Cross-origin policy for the whole application: any origin, any method and
# any header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is the
# most permissive configuration possible; FastAPI's CORS documentation asks
# for explicit origins when credentials are enabled. Confirm whether browser
# clients really need credentials before tightening this.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Pick the compute device once at import time: CUDA first, then Apple MPS,
# then CPU. CUDA has to be part of this decision because the request handler
# only turns on its KV-cache options for CUDA; without this branch a CUDA host
# would run the model on the CPU while those CUDA-only options were enabled.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Load every pretrained component from the same model identifier. The loads
# are independent of each other; only the processor needs the tokenizer, and
# the VAE is read from the "vae" subfolder.
model_path = "Shitao/OmniGen-v1"
print("Loading model components...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = OmniGenProcessor(tokenizer)
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
model = OmniGen.from_pretrained(model_path)

# Assemble the inference pipeline on the module-level device.
pipe = OmniGenPipeline(vae=vae, model=model, processor=processor, device=device)
class GenerationRequest(BaseModel):
    """Schema describing one image-generation request.

    Only ``prompt`` is required; every other field has a default. The field
    names and defaults match the form parameters of the ``/generate`` route.
    NOTE(review): no route in the visible part of this file takes this model
    as its body; confirm whether it is used elsewhere.
    """

    # Required text prompt.
    prompt: str

    # Requested output size.
    height: Optional[int] = 1024
    width: Optional[int] = 1024

    # Guidance strengths and number of inference steps.
    guidance_scale: Optional[float] = 2.5
    img_guidance_scale: Optional[float] = 1.6
    inference_steps: Optional[int] = 50

    # Explicit seed; None means no seed was supplied.
    seed: Optional[int] = None

    # Pipeline execution switches.
    separate_cfg_infer: Optional[bool] = True
    offload_model: Optional[bool] = False
    use_input_image_size_as_output: Optional[bool] = False
    max_input_image_size: Optional[int] = 1024

    # Request-level behaviour flags.
    randomize_seed: Optional[bool] = True
    save_images: Optional[bool] = False
async def process_image(image: UploadFile) -> Optional[str]:
    """Persist an uploaded image as a PNG file and return the file's path.

    Args:
        image: The uploaded file, or ``None`` when the field was not sent.

    Returns:
        The path of the newly written PNG file, or ``None`` when ``image`` is
        ``None``. The caller owns the file and is responsible for deleting it.

    Raises:
        HTTPException: With status 400 when the upload cannot be read, decoded
            or saved.
    """
    # Imported locally so this function does not depend on edits to the
    # file-level import block.
    import tempfile

    if image is None:
        return None

    temp_path = None
    try:
        contents = await image.read()
        with Image.open(io.BytesIO(contents)) as img:
            # A timestamp-based name only has one-second resolution, so the
            # uploads of a single request (or of concurrent requests) would
            # overwrite each other. mkstemp guarantees a unique file.
            fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
            os.close(fd)
            img.save(temp_path)
        return temp_path
    except Exception as e:
        # Do not leave an empty or partially written file behind on failure.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                print(f"Could not remove temporary file {temp_path}: {cleanup_error}")
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}") from e
def _remove_temp_files(paths: List[str]) -> None:
    """Best-effort removal of the temporary input images of one request."""
    for path in paths:
        if not os.path.exists(path):
            continue
        try:
            os.remove(path)
        except OSError as cleanup_error:
            # Cleanup must never replace the real outcome of the request.
            print(f"Could not remove temporary file {path}: {cleanup_error}")


@app.post("/generate")
async def generate_image(
    prompt: str = Form(...),
    image1: Optional[UploadFile] = File(None),
    image2: Optional[UploadFile] = File(None),
    image3: Optional[UploadFile] = File(None),
    height: int = Form(1024),
    width: int = Form(1024),
    guidance_scale: float = Form(2.5),
    img_guidance_scale: float = Form(1.6),
    inference_steps: int = Form(50),
    seed: Optional[int] = Form(None),
    separate_cfg_infer: bool = Form(True),
    offload_model: bool = Form(False),
    use_input_image_size_as_output: bool = Form(False),
    max_input_image_size: int = Form(1024),
    randomize_seed: bool = Form(True),
    save_images: bool = Form(False)
):
    """Run the pipeline for one prompt and return the image as base64 PNG.

    Up to three optional uploads are written to temporary files and handed to
    the pipeline as input images; those files are removed again whether or
    not generation succeeds.

    A fresh random seed is drawn when ``randomize_seed`` is true (the default)
    or when no ``seed`` was supplied, so an explicit ``seed`` only takes
    effect together with ``randomize_seed=False``.

    Returns:
        A dict with ``status`` ("success"), ``image`` (base64-encoded PNG) and
        ``seed`` (the seed actually used).

    Raises:
        HTTPException: Status 400 (propagated unchanged) when an upload cannot
            be processed; status 500 for any other failure.
    """
    temp_paths: List[str] = []
    try:
        # Process input images
        for upload in (image1, image2, image3):
            if upload is None:
                continue
            saved_path = await process_image(upload)
            if saved_path:
                temp_paths.append(saved_path)

        if randomize_seed or seed is None:
            seed = random.randint(0, 10000000)

        # Enable the KV cache only when the pipeline itself runs on CUDA.
        # Checking the pipeline's device (rather than whether CUDA merely
        # exists on the host) keeps this switch consistent with where the
        # model actually lives.
        on_cuda = device.type == "cuda"

        # Generate image
        output = pipe(
            prompt=prompt,
            input_images=temp_paths or None,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            img_guidance_scale=img_guidance_scale,
            num_inference_steps=inference_steps,
            separate_cfg_infer=separate_cfg_infer,
            use_kv_cache=on_cuda,
            offload_kv_cache=on_cuda,
            offload_model=offload_model,
            use_input_image_size_as_output=use_input_image_size_as_output,
            seed=seed,
            max_input_image_size=max_input_image_size,
        )
        result_image = output[0]

        # Save image if requested
        if save_images:
            os.makedirs('outputs', exist_ok=True)
            timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
            result_image.save(os.path.join('outputs', f'{timestamp}.png'))

        # Convert image to base64
        buffered = io.BytesIO()
        result_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {
            "status": "success",
            "image": img_str,
            "seed": seed
        }
    except HTTPException:
        # Keep the status code chosen further down (e.g. 400 for a bad
        # upload) instead of relabelling it as a server error.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Runs on success and on failure, so temp files never accumulate.
        _remove_temp_files(temp_paths)
@app.get("/health")
async def health_check():
    """Liveness probe: return a fixed status and the module-level device."""
    return dict(status="healthy", device=str(device))
if __name__ == "__main__":
    # Script entry point: serve the ``app`` object of the ``api`` module on
    # all interfaces, port 8000, with reload turned on.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("api:app", **server_options)