GheeButter committed on
Commit
0fa4a31
·
1 Parent(s): 0704ca2
Files changed (2) hide show
  1. app.py +52 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
4
 
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
  import torch
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  model_repo_id = "Tongyi-MAI/Z-Image-Turbo"
@@ -20,6 +22,34 @@ pipe = pipe.to(device)
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  @spaces.GPU
25
  def infer(
@@ -31,8 +61,14 @@ def infer(
31
  height,
32
  guidance_scale,
33
  num_inference_steps,
 
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
 
 
 
 
 
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
@@ -48,7 +84,7 @@ def infer(
48
  generator=generator,
49
  ).images[0]
50
 
51
- return image, seed
52
 
53
 
54
  examples = [
@@ -79,6 +115,19 @@ with gr.Blocks(css=css) as demo:
79
 
80
  run_button = gr.Button("Run", scale=0, variant="primary")
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  result = gr.Image(label="Result", show_label=False)
83
 
84
  with gr.Accordion("Advanced Settings", open=False):
@@ -146,8 +195,9 @@ with gr.Blocks(css=css) as demo:
146
  height,
147
  guidance_scale,
148
  num_inference_steps,
 
149
  ],
150
- outputs=[result, seed],
151
  )
152
 
153
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import os
5
 
6
  import spaces
7
  from diffusers import DiffusionPipeline
8
  import torch
9
+ from huggingface_hub import InferenceClient
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model_repo_id = "Tongyi-MAI/Z-Image-Turbo"
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  MAX_IMAGE_SIZE = 1024
24
 
25
+ # Initialize LLM for prompt enhancement
26
+ llm_client = InferenceClient()
27
+
28
def enhance_prompt(prompt: str) -> str:
    """Enhance *prompt* with an LLM so it is more descriptive for image generation.

    Sends the prompt to the chat-completion endpoint and returns the model's
    rewritten version. Best-effort by design: on ANY failure (network, auth,
    rate limit, malformed response) the original prompt is returned unchanged
    so image generation still proceeds.

    Args:
        prompt: The user's raw text-to-image prompt.

    Returns:
        The enhanced prompt, or the original prompt when enhancement fails
        or yields no usable text.
    """
    # Blank input: nothing to enhance, skip the API round-trip entirely.
    if not prompt or not prompt.strip():
        return prompt
    try:
        system_message = """You are an expert at crafting detailed prompts for text-to-image models.
Given a simple prompt, enhance it by adding relevant details about style, lighting, composition, and quality.
Keep the core concept but make it more descriptive. Return only the enhanced prompt, nothing else."""

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Enhance this prompt for image generation: {prompt}"}
        ]

        response = llm_client.chat_completion(
            messages=messages,
            model="meta-llama/Llama-3.3-70B-Instruct",
            max_tokens=200,
            temperature=0.7,
        )

        # The API may return None or whitespace-only content; previously that
        # either raised AttributeError on .strip() or silently fed an EMPTY
        # prompt to the image model. Fall back to the original prompt instead.
        content = response.choices[0].message.content
        enhanced = (content or "").strip()
        return enhanced if enhanced else prompt
    except Exception as e:
        # Deliberately broad: enhancement is an optional nicety and must
        # never block image generation.
        print(f"Error enhancing prompt: {e}")
        return prompt  # Return original if enhancement fails
52
+
53
 
54
  @spaces.GPU
55
  def infer(
 
61
  height,
62
  guidance_scale,
63
  num_inference_steps,
64
+ use_prompt_enhancement,
65
  progress=gr.Progress(track_tqdm=True),
66
  ):
67
+ # Enhance prompt if requested
68
+ original_prompt = prompt
69
+ if use_prompt_enhancement:
70
+ prompt = enhance_prompt(prompt)
71
+
72
  if randomize_seed:
73
  seed = random.randint(0, MAX_SEED)
74
 
 
84
  generator=generator,
85
  ).images[0]
86
 
87
+ return image, seed, prompt
88
 
89
 
90
  examples = [
 
115
 
116
  run_button = gr.Button("Run", scale=0, variant="primary")
117
 
118
+ use_prompt_enhancement = gr.Checkbox(
119
+ label="✨ Enhance prompt with AI",
120
+ value=False,
121
+ info="Use an LLM to make your prompt more detailed"
122
+ )
123
+
124
+ enhanced_prompt_display = gr.Textbox(
125
+ label="Enhanced Prompt",
126
+ interactive=False,
127
+ visible=True,
128
+ lines=2
129
+ )
130
+
131
  result = gr.Image(label="Result", show_label=False)
132
 
133
  with gr.Accordion("Advanced Settings", open=False):
 
195
  height,
196
  guidance_scale,
197
  num_inference_steps,
198
+ use_prompt_enhancement,
199
  ],
200
+ outputs=[result, seed, enhanced_prompt_display],
201
  )
202
 
203
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -3,4 +3,5 @@ diffusers
3
  invisible_watermark
4
  torch
5
  transformers
6
- xformers
 
 
3
  invisible_watermark
4
  torch
5
  transformers
6
+ xformers
7
+ huggingface_hub