# OmniGen FastAPI server (extracted from a Hugging Face Space whose status page reported "Runtime error").
# Standard library
import base64
import io
import os
import random
import uuid
from datetime import datetime
from typing import List, Optional

# Third-party
import torch
import uvicorn
from diffusers import AutoencoderKL
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer

# Local
from OmniGen import OmniGen, OmniGenProcessor, OmniGenPipeline
# Initialize FastAPI app
app = FastAPI(
    title="OmniGen API",
    description="REST API for OmniGen: Unified Image Generation",
    version="1.0.0",
)

# Allow cross-origin requests from any host; tighten `allow_origins`
# before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pick the best available accelerator: CUDA > Apple MPS > CPU.
# (The original checked only MPS, yet the request handler already
# special-cases CUDA for KV caching — select CUDA here when present.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load the model, tokenizer and VAE once at import time so every
# request reuses the same pipeline instance.
model_path = "Shitao/OmniGen-v1"
print("Loading model components...")
model = OmniGen.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
processor = OmniGenProcessor(tokenizer)

# Assemble the end-to-end text(+images) -> image pipeline.
pipe = OmniGenPipeline(
    vae=vae,
    model=model,
    processor=processor,
    device=device,
)
class GenerationRequest(BaseModel):
    """Pydantic schema mirroring the form fields of the generation endpoint.

    NOTE(review): this model is not referenced anywhere in this file — the
    generation handler takes multipart form fields directly. Presumably kept
    for a JSON variant of the API; confirm before removing.
    """

    prompt: str
    height: Optional[int] = 1024
    width: Optional[int] = 1024
    guidance_scale: Optional[float] = 2.5          # classifier-free guidance strength
    img_guidance_scale: Optional[float] = 1.6      # guidance strength for reference images
    inference_steps: Optional[int] = 50
    seed: Optional[int] = None                     # None -> a random seed is drawn per request
    separate_cfg_infer: Optional[bool] = True
    offload_model: Optional[bool] = False
    use_input_image_size_as_output: Optional[bool] = False
    max_input_image_size: Optional[int] = 1024
    randomize_seed: Optional[bool] = True
    save_images: Optional[bool] = False            # also write the PNG under outputs/
async def process_image(image: UploadFile) -> Optional[str]:
    """Persist an uploaded image to a temporary PNG file on disk.

    Args:
        image: The uploaded file, or None when the client sent nothing.

    Returns:
        Path of the temporary PNG, or None if no file was given. The
        caller is responsible for deleting the file when done.

    Raises:
        HTTPException: 400 if the payload cannot be decoded as an image.
    """
    if image is None:
        return None
    try:
        contents = await image.read()
        img = Image.open(io.BytesIO(contents))
        # Use a UUID rather than a second-resolution timestamp: two
        # requests arriving within the same second previously produced
        # the same filename and overwrote each other's input.
        temp_path = f"temp_{uuid.uuid4().hex}.png"
        img.save(temp_path)
        return temp_path
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}")
@app.post("/generate")
async def generate_image(
    prompt: str = Form(...),
    image1: Optional[UploadFile] = File(None),
    image2: Optional[UploadFile] = File(None),
    image3: Optional[UploadFile] = File(None),
    height: int = Form(1024),
    width: int = Form(1024),
    guidance_scale: float = Form(2.5),
    img_guidance_scale: float = Form(1.6),
    inference_steps: int = Form(50),
    seed: Optional[int] = Form(None),
    separate_cfg_infer: bool = Form(True),
    offload_model: bool = Form(False),
    use_input_image_size_as_output: bool = Form(False),
    max_input_image_size: int = Form(1024),
    randomize_seed: bool = Form(True),
    save_images: bool = Form(False),
):
    """Generate an image from a text prompt and up to three reference images.

    Returns a JSON object: {"status": "success", "image": <base64 PNG>,
    "seed": <int seed actually used>}.

    NOTE(review): the original file defined this handler without a route
    decorator, so it was never mounted; "/generate" restores the
    conventional mount point — confirm against the client code.
    """
    saved_paths: List[str] = []  # temp files to delete, success or failure
    try:
        # Persist each uploaded reference image to a temp file.
        for upload in (image1, image2, image3):
            if upload is not None:
                temp_path = await process_image(upload)
                if temp_path:
                    saved_paths.append(temp_path)
        # The pipeline expects None, not an empty list, when no images given.
        input_images = saved_paths or None

        if randomize_seed or seed is None:
            seed = random.randint(0, 10000000)

        # KV caching is only enabled on CUDA; MPS/CPU run without it.
        use_kv_cache = torch.cuda.is_available()
        offload_kv_cache = use_kv_cache

        # Run the diffusion pipeline.
        output = pipe(
            prompt=prompt,
            input_images=input_images,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            img_guidance_scale=img_guidance_scale,
            num_inference_steps=inference_steps,
            separate_cfg_infer=separate_cfg_infer,
            use_kv_cache=use_kv_cache,
            offload_kv_cache=offload_kv_cache,
            offload_model=offload_model,
            use_input_image_size_as_output=use_input_image_size_as_output,
            seed=seed,
            max_input_image_size=max_input_image_size,
        )
        img = output[0]

        # Optionally keep a timestamped copy on disk.
        if save_images:
            os.makedirs('outputs', exist_ok=True)
            timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
            img.save(os.path.join('outputs', f'{timestamp}.png'))

        # Encode the result as a base64 PNG for the JSON response.
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return {
            "status": "success",
            "image": img_str,
            "seed": seed
        }
    except HTTPException:
        # Preserve client errors (e.g. the 400 raised by process_image)
        # instead of rewrapping them as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Cleanup previously ran only on the success path, leaking temp
        # files whenever the pipeline raised; `finally` guarantees removal.
        for temp_path in saved_paths:
            if os.path.exists(temp_path):
                os.remove(temp_path)
@app.get("/health")
async def health_check():
    """Liveness probe: report service status and the active inference device.

    NOTE(review): the original defined this handler without a route
    decorator, so it was never mounted; "/health" is the conventional path.
    """
    return {"status": "healthy", "device": str(device)}
if __name__ == "__main__":
    # Development entry point. The "api:app" import string assumes this
    # file is named api.py; `reload=True` is for local development only.
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)