File size: 4,108 Bytes
755b512
 
 
075c8f0
755b512
 
075c8f0
 
 
 
755b512
075c8f0
755b512
1a974b2
755b512
075c8f0
 
6c5d32a
755b512
8bac254
 
 
 
 
 
6c5d32a
755b512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d82a1
 
380deae
22f2658
755b512
44895ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755b512
44895ee
 
 
 
 
 
755b512
 
44895ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755b512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d82a1
755b512
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import torch
import spaces
import os
from diffusers import FluxPipeline
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# 1. Read the Hugging Face access token configured in the Space settings.
#    It is forwarded to every Hub download below (FLUX.1-dev is a gated repo).
hf_token = os.environ.get("HF_TOKEN")

# 2. Load the base FLUX.1-dev pipeline in bfloat16 and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    token=hf_token,
).to('cuda')

# 3. Overwrite the base transformer weights with Tencent's SRPO checkpoint.
srpo_path = hf_hub_download(
    repo_id="tencent/SRPO",
    filename="diffusion_pytorch_model.safetensors",
    token=hf_token,
)
state_dict = load_file(srpo_path)
# Default strict loading: every key must match the transformer's parameters.
pipe.transformer.load_state_dict(state_dict)
# load_state_dict copies the tensors into the model's own parameters, so the
# loaded host-side copy is no longer needed. Drop the last reference instead of
# keeping a full set of transformer weights alive in RAM for the whole process.
del state_dict

def _resolve_seed(seed):
    """Turn the UI seed value into a concrete integer seed.

    Args:
        seed: Value from the "Seed" Number box. ``-1`` requests a random
            seed. ``None`` (what an emptied Number box yields) is treated
            the same way instead of being passed on to ``manual_seed``.

    Returns:
        An ``int`` suitable for ``torch.Generator.manual_seed``.
    """
    if seed is None or int(seed) == -1:
        # randint's upper bound is exclusive, so this draws from [0, 2**32).
        return int(torch.randint(0, 2**32, (1,)).item())
    # UI numbers may arrive as floats; manual_seed requires an int.
    return int(seed)


@spaces.GPU(duration=120)
def generate_image(
    prompt,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=-1,
):
    """Generate one image from *prompt* with the SRPO-weighted FLUX pipeline.

    The ``spaces.GPU(duration=120)`` decorator requests a GPU for up to
    120 seconds per call when running on Hugging Face ZeroGPU hardware.

    Args:
        prompt: Text description of the desired image.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance strength passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed; ``-1`` or ``None`` selects a random seed.

    Returns:
        ``(image, seed)``: the first image of the pipeline output and the
        integer seed actually used, so the UI can show it for reproduction.
    """
    seed = _resolve_seed(seed)
    generator = torch.Generator(device='cuda').manual_seed(seed)

    result = pipe(
        prompt=prompt,
        guidance_scale=float(guidance_scale),
        # Slider values are coerced defensively; these are counts/sizes.
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        # Maximum prompt token length; 512 is the cap FluxPipeline accepts.
        max_sequence_length=512,
        generator=generator,
    )
    return result.images[0], seed

# Header lines rendered, in order, at the top of the page.
_HEADER_MARKDOWN = (
    "# Flux SRPO",
    "Generate images using FLUX model enhanced with Tencent's [SRPO](https://github.com/Tencent-Hunyuan/SRPO) technique",
    "Built with [AnyCoder](https://huggingface.co/spaces/akhaliq/anycoder)",
)

# Prompts offered as one-click examples that fill the prompt box.
_EXAMPLE_PROMPTS = (
    "The Death of Ophelia by John Everett Millais, Pre-Raphaelite painting, Ophelia floating in a river surrounded by flowers, detailed natural elements, melancholic and tragic atmosphere",
    "A serene Japanese garden with cherry blossoms, koi pond, traditional wooden bridge, soft morning light, photorealistic",
    "Cyberpunk cityscape at night, neon lights, flying cars, rain-slicked streets, blade runner aesthetic, highly detailed",
    "Portrait of a majestic lion in golden hour light, detailed fur texture, intense gaze, African savanna background",
    "Abstract colorful explosion of paint in water, high speed photography, vibrant colors mixing, dramatic lighting",
)

_THEME = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
)


def _dimension_slider(label):
    """Create a pixel-size slider: 256-2048 in steps of 64, defaulting to 1024."""
    return gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label=label)


with gr.Blocks(title="FLUX SRPO Text-to-Image", theme=_THEME) as demo:
    for _markdown_text in _HEADER_MARKDOWN:
        gr.Markdown(_markdown_text)

    # Result image first, then the prompt box and the button that runs generation.
    output_image = gr.Image(label="Generated Image", type="pil")
    prompt = gr.Textbox(
        lines=3,
        label="Prompt",
        placeholder="Describe the image you want to generate...",
    )
    generate_btn = gr.Button("Generate Image", variant="primary", size="lg")

    # Collapsed by default: image size, sampler settings and seed controls.
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = _dimension_slider("Width")
            height = _dimension_slider("Height")

        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=20.0, step=0.5, value=3.5
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=10, maximum=100, step=5, value=50
            )

        seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
        # Shows the seed value returned by generate_image.
        used_seed = gr.Number(label="Seed Used", precision=0)

    gr.Examples(
        examples=[[text] for text in _EXAMPLE_PROMPTS],
        inputs=prompt,
        label="Example Prompts",
    )

    # Order must match generate_image's positional parameters.
    _generate_inputs = [prompt, width, height, guidance_scale, num_inference_steps, seed]
    generate_btn.click(
        fn=generate_image,
        inputs=_generate_inputs,
        outputs=[output_image, used_seed],
    )

demo.launch()