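# FastAPI inference service for a Hugging Face Space: generates images with
# black-forest-labs/FLUX.1-dev (accelerated by the ByteDance Hyper-SD 8-step
# LoRA) and uploads the results to an S3 bucket.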
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import random
import torch
import boto3
from io import BytesIO
import time
import os
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline

# S3 configuration from environment variables
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY")

# Fail fast if any S3 setting is missing
if not all([S3_BUCKET, S3_REGION, S3_ACCESS_KEY_ID, S3_SECRET_ACCESS_KEY]):
    raise ValueError("Missing required S3 environment variables")

# Set up the S3 client
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)

# Set up the model cache path. Note that some versions of the HF libraries
# read these variables at import time, so they are best set before the
# imports above.
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
os.makedirs(cache_path, exist_ok=True)

# Set up CUDA: allow TF32 matmuls for speed and pick the device
torch.backends.cuda.matmul.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize FluxPipeline and fuse in the Hyper-SD 8-step LoRA
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)  # scale recommended by ByteDance for the 8-step LoRA
pipe.to(device=device, dtype=torch.bfloat16)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048  # upper bound on width/height, enforced in infer() below

app = FastAPI()

class InferenceRequest(BaseModel):
    prompt: str
    seed: int = 42
    randomize_seed: bool = True
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 3.5
    num_inference_steps: int = 8

class Timer:
    """Context manager that logs how long the wrapped block took."""

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")

def save_image_to_s3(image):
    """Upload a PIL image to S3 as PNG and return its public URL."""
    img_byte_arr = BytesIO()
    image.save(img_byte_arr, format="PNG")
    filename = f"generated_image_{int(time.time())}.png"
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=img_byte_arr.getvalue(),
        ContentType="image/png",
    )
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"

def process_image(height, width, steps, scales, prompt, seed):
    """Run one FLUX inference pass and return the generated PIL image."""
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), Timer("inference"):
        return pipe(
            prompt=[prompt],
            generator=torch.Generator().manual_seed(int(seed)),
            num_inference_steps=int(steps),
            guidance_scale=float(scales),
            height=int(height),
            width=int(width),
            max_sequence_length=256,
        ).images[0]

@app.post("/infer")  # route decorator needed for a runnable app; the path itself is an assumption
async def infer(request: InferenceRequest):
    if request.randomize_seed:
        seed = random.randint(0, MAX_SEED)
    else:
        seed = request.seed
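    # Hypothetical guard, not in the original: MAX_IMAGE_SIZE is defined above
    # but never checked, so reject oversized requests before running inference.
    if request.width > MAX_IMAGE_SIZE or request.height > MAX_IMAGE_SIZE:
        raise HTTPException(status_code=400,
                            detail=f"width and height must be <= {MAX_IMAGE_SIZE}")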
    try:
        image = process_image(
            height=request.height,
            width=request.width,
            steps=request.num_inference_steps,
            scales=request.guidance_scale,
            prompt=request.prompt,
            seed=seed,
        )
        image_url = save_image_to_s3(image)
        return {"image_url": image_url, "seed": seed}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")  # route decorator needed for a runnable app; the path itself is an assumption
async def root():
    return {"message": "Welcome to the IG API"}