| import argparse |
| import base64 |
| import io |
| import time |
| import torch |
| import uvicorn |
| import gc |
| import asyncio |
| import os |
| import sys |
| import os |
| import inspect |
|
|
| |
| |
# The OmniGen2 sources are expected under <project_root>/packages/OmniGen2,
# where <project_root> is the parent of the directory holding this file.
# Putting that directory at the very front of sys.path makes the `omnigen2.*`
# imports below resolve to it before any other copy on the path.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
omnigen_path = os.path.join(project_root, "packages", "OmniGen2")
sys.path[0:0] = [omnigen_path]
|
|
| from typing import List, Optional |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form |
| from pydantic import BaseModel |
| from PIL import Image, ImageOps |
|
|
| |
| from omnigen2.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline |
| from omnigen2.models.transformers.transformer_omnigen2 import OmniGen2Transformer2DModel |
| from transformers import CLIPProcessor, BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration |
| from transformers.modeling_utils import no_init_weights |
|
|
| |
| |
|
|
| |
def _build_arg_parser():
    """Return the command-line parser for the image server."""
    cli = argparse.ArgumentParser(description="OmniGen2 Image Edit Server")

    # Where uvicorn listens; "0.0.0.0" binds every network interface.
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    cli.add_argument("--port", type=int, default=8000, help="Port to bind to")

    # Model location and precision. A relative --base-model path (including the
    # default) is interpreted relative to the current working directory, not to
    # this file.
    cli.add_argument(
        "--base-model",
        type=str,
        default="../models/OmniGen2",
        help="Path to base OmniGen2 model",
    )
    cli.add_argument(
        "--dtype",
        type=str,
        default='bf16',
        choices=['fp32', 'fp16', 'bf16'],
        help="Model precision",
    )
    return cli


parser = _build_arg_parser()
# Parsed at import time: importing this module reads (and validates) sys.argv.
args = parser.parse_args()
|
|
app = FastAPI()

# Shared state used by the request handlers:
#   request_lock -- held around every pipeline call so only one request uses
#                   the model at a time.
#   pipeline     -- the loaded OmniGen2 pipeline; None until load_model() runs.
# NOTE(review): on Python < 3.10 an asyncio.Lock binds to the event loop that is
# current when it is created (here: import time); this assumes 3.10+ — confirm.
request_lock = asyncio.Lock()
pipeline = None
|
|
def load_model():
    """Load the OmniGen2 pipeline onto CUDA and publish it as the ``pipeline`` global.

    The Qwen2.5-VL MLLM is loaded with 4-bit NF4 bitsandbytes quantization and
    injected into ``OmniGen2Pipeline.from_pretrained``. The remaining components are
    loaded from ``args.base_model`` in the precision chosen by ``--dtype``.

    The global ``pipeline`` is assigned only after every step below has succeeded,
    so a failure part-way through never leaves a half-configured pipeline visible
    to the request handlers.

    Raises:
        Exception: anything raised by the loaders; it is logged and re-raised.
    """
    global pipeline

    print(f"Loading OmniGen2 from {args.base_model}...")

    # argparse limits --dtype to fp32/fp16/bf16; fp32 is also the fallback.
    weight_dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }.get(args.dtype, torch.float32)

    try:
        print("Loading MLLM in 4-bit mode for extra village efficiency!")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=weight_dtype,
        )
        mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.base_model,
            subfolder="mllm",
            quantization_config=quantization_config,
            torch_dtype=weight_dtype,
        )

        loaded = OmniGen2Pipeline.from_pretrained(
            args.base_model,
            mllm=mllm,
            processor=CLIPProcessor.from_pretrained(
                args.base_model,
                subfolder="processor",
                use_fast=True,
            ),
            torch_dtype=weight_dtype,
            trust_remote_code=True,
        ).to("cuda")

        loaded.enable_taylorseer = True
        loaded.transformer.set_attention_backend("flash")
        # No CPU offload is enabled: the pipeline was moved to CUDA above and is
        # used from there as-is.
    except Exception as e:
        print(f"Oh no! The OmniGen2 spirit refused to manifest: {e}")
        raise

    pipeline = loaded
    print("OmniGen2 loaded successfully! Let's paint the village!")
|
|
def flush():
    """Release memory after a request.

    First runs a full garbage-collection pass, so objects kept alive only by
    reference cycles are freed. Then asks PyTorch's CUDA caching allocator to
    return its unused cached blocks.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
|
|
class ImageGenerationRequest(BaseModel):
    """JSON request body for ``POST /v1/images/generations``.

    The generation handler in this file reads only ``prompt``, ``n`` and ``size``;
    ``response_format``, ``quality`` and ``style`` are accepted but not used.
    """
    prompt: str  # text prompt passed to the pipeline
    n: int = 1  # forwarded as num_images_per_prompt
    size: str = "1024x1024"  # "<width>x<height>"; the handler parses it with split("x")
    response_format: str = "b64_json"  # accepted; not read by the handler
    quality: str = "standard"  # accepted; not read by the handler
    style: str = "vivid"  # accepted; not read by the handler
|
|
@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: load the OmniGen2 pipeline once when the app starts.

    ``load_model`` is synchronous, so the load runs directly on the event loop.
    Any exception it raises propagates out of this hook.
    """
    return load_model()
|
|
@app.post("/v1/images/edits")
async def edit_image(
    image: UploadFile = File(...),
    prompt: str = Form(...),
    n: int = Form(1),
    size: str = Form("1024x1024"),
    response_format: str = Form("b64_json"),
    guidance_scale: float = Form(2.5),
    strength: float = Form(1.0),
):
    """Edit an uploaded image according to ``prompt`` with the OmniGen2 pipeline.

    Form fields:
        image: source image. EXIF orientation is applied, then it is converted to RGB.
        prompt: edit instruction passed to the pipeline.
        n: number of output images (``num_images_per_prompt``); must be >= 1.
        size: ``"<width>x<height>"`` bounding box. The source image is scaled to fit
            inside it with its aspect ratio preserved. Each side is then floored to a
            multiple of 16, never below 16. An unparseable or non-positive value
            falls back to 1024x1024.
        response_format: accepted but ignored; results are always base64 PNG.
        guidance_scale: forwarded as the pipeline's ``image_guidance_scale``.
        strength: accepted but ignored.

    Returns:
        ``{"created": <unix seconds>, "data": [{"b64_json": <base64 PNG>}, ...]}``

    Raises:
        HTTPException: 400 for an unreadable image or ``n < 1``; 500 when the model
            is not loaded or generation fails.
    """
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    if n < 1:
        raise HTTPException(status_code=400, detail="n must be at least 1")

    # Decode and validate the upload before queueing on the GPU lock, so a bad
    # request is rejected without waiting behind other generations.
    try:
        contents = await image.read()
        # Apply EXIF orientation before converting: convert() returns a new image
        # that may not carry the source file's orientation metadata.
        init_image = ImageOps.exif_transpose(Image.open(io.BytesIO(contents))).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    try:
        target_width, target_height = map(int, size.split("x"))
    except ValueError:
        target_width, target_height = 1024, 1024
    if target_width <= 0 or target_height <= 0:
        target_width, target_height = 1024, 1024

    # Fit the source inside the target box, then floor each side to a multiple of
    # 16 (presumably a model requirement — TODO confirm). Clamp to 16 so very thin
    # images cannot collapse to a zero-sized output.
    orig_width, orig_height = init_image.size
    scale = min(target_width / orig_width, target_height / orig_height)
    width = max(16, int(orig_width * scale) // 16 * 16)
    height = max(16, int(orig_height * scale) // 16 * 16)

    def _generate():
        # Runs in a worker thread: both the pipeline call and PNG encoding block.
        results = pipeline(
            prompt=prompt,
            input_images=[init_image],
            width=width,
            height=height,
            num_inference_steps=26,
            max_sequence_length=1024,
            text_guidance_scale=5.0,
            image_guidance_scale=guidance_scale,
            cfg_range=(0.0, 1.0),
            negative_prompt="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
            num_images_per_prompt=n,
            output_type="pil",
        )
        encoded = []
        for img in results.images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            encoded.append({"b64_json": img_str})
        return encoded

    async with request_lock:
        print(f"Received edit request: {prompt}")
        try:
            # Off-load the blocking work so the event loop keeps serving while the
            # GPU is busy. request_lock still allows only one generation at a time.
            # NOTE: if this coroutine is cancelled mid-await, the worker thread keeps
            # running but the lock is released.
            response_images = await asyncio.get_running_loop().run_in_executor(None, _generate)
        except Exception as e:
            print(f"Error during editing: {e}")
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
        finally:
            flush()

    return {
        "created": int(time.time()),
        "data": response_images,
    }
|
|
@app.post("/v1/images/generations")
async def generate_image(request: ImageGenerationRequest):
    """Generate images from a text prompt with the OmniGen2 pipeline.

    Of the request fields only ``prompt``, ``n`` and ``size`` are used:
        - ``n`` is forwarded as ``num_images_per_prompt`` and must be >= 1.
        - ``size`` is ``"<width>x<height>"``. Each side is floored to a multiple of
          16, never below 16, and an unparseable or non-positive value falls back
          to 1024x1024.

    ``response_format``, ``quality`` and ``style`` are ignored; results are always
    base64-encoded PNGs.

    Returns:
        ``{"created": <unix seconds>, "data": [{"b64_json": <base64 PNG>}, ...]}``

    Raises:
        HTTPException: 400 when ``n < 1``; 500 when the model is not loaded or the
            pipeline fails.
    """
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    if request.n < 1:
        raise HTTPException(status_code=400, detail="n must be at least 1")

    try:
        width, height = map(int, request.size.split("x"))
    except ValueError:
        width, height = 1024, 1024
    if width <= 0 or height <= 0:
        width, height = 1024, 1024

    # Floor to a multiple of 16 (presumably a model requirement — TODO confirm).
    # Clamp so sizes below 16 cannot become a zero-sized output.
    width = max(16, width // 16 * 16)
    height = max(16, height // 16 * 16)

    def _generate():
        # Runs in a worker thread: both the pipeline call and PNG encoding block.
        results = pipeline(
            prompt=request.prompt,
            input_images=None,
            width=width,
            height=height,
            num_inference_steps=26,
            max_sequence_length=1024,
            text_guidance_scale=5.0,
            image_guidance_scale=2.0,
            cfg_range=(0.0, 1.0),
            negative_prompt="(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar",
            num_images_per_prompt=request.n,
            output_type="pil",
        )
        encoded = []
        for img in results.images:
            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            encoded.append({"b64_json": img_str})
        return encoded

    async with request_lock:
        print(f"Received generation request: {request.prompt}")
        try:
            # Off-load the blocking work so the event loop keeps serving while the
            # GPU is busy. request_lock still allows only one generation at a time.
            # NOTE: if this coroutine is cancelled mid-await, the worker thread keeps
            # running but the lock is released.
            response_images = await asyncio.get_running_loop().run_in_executor(None, _generate)
        except Exception as e:
            print(f"Error during generation: {e}")
            raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
        finally:
            flush()

    return {
        "created": int(time.time()),
        "data": response_images,
    }
|
|
if __name__ == "__main__":
    # Script entry point: serve the app in-process on the address given by the
    # --host/--port command-line flags.
    uvicorn.run(app, port=args.port, host=args.host)
|
|