# Source: Baggio200cn, "Upload app.py" (commit 2b9a578, verified).
import os
import gradio as gr
from huggingface_hub import snapshot_download
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
# Dependencies (huggingface-hub, transformers, diffusers, accelerate,
# safetensors, torch, gradio) must be installed *before* this script starts,
# e.g. via requirements.txt. The imports above have already executed, so a
# `pip install` at this point cannot make a missing package importable in this
# process; it only slowed startup and risked upgrading packages already in use.

# Download the PosterCraft transformer weights into a local directory.
# HF_TOKEN, if set in the environment, is passed to the Hub for authentication.
model_dir = "models/postercraft/postercraft_rl"
hf_token = os.environ.get("HF_TOKEN")
# Interrupted downloads resume automatically in current huggingface_hub; the
# former `resume_download` flag is deprecated and only triggers a warning.
snapshot_download(
    repo_id="PosterCraft/PosterCraft-v1_RL",
    local_dir=model_dir,
    token=hf_token,
)

# Load the model: the FLUX.1-dev pipeline with its transformer replaced by the
# PosterCraft one downloaded above.
pipeline_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
# Load the replacement transformer first and hand it to the pipeline, so the
# base FLUX.1-dev transformer is not loaded only to be thrown away.
postercraft_transformer = FluxTransformer2DModel.from_pretrained(
    model_dir, torch_dtype=dtype, token=hf_token
)
pipe = FluxPipeline.from_pretrained(
    pipeline_id,
    transformer=postercraft_transformer,
    torch_dtype=dtype,
    token=hf_token,
)
# NOTE(review): the CPU fallback keeps bfloat16 weights; expect generation on
# CPU to be very slow for a model of this size.
pipe.to("cuda" if torch.cuda.is_available() else "cpu")
# Generate function
def generate_poster(prompt):
try:
image = pipe(prompt).images[0]
return image
except Exception as e:
return f"Error: {str(e)}"
# Gradio UI: one text box (the prompt) in, one generated image out.
iface = gr.Interface(
    fn=generate_poster,
    inputs="text",
    outputs="image",
    title="Test Poster Generator",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not start a blocking web server.
if __name__ == "__main__":
    iface.launch()