CryptoCreeper committed on
Commit
56151a0
·
verified ·
1 Parent(s): ccb52fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -25
app.py CHANGED
@@ -1,68 +1,59 @@
1
  import gradio as gr
2
  import torch
3
  import re
 
4
  from diffusers import DiffusionPipeline
5
  from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
6
 
7
- # 1. Setup Device
8
  device = "cpu"
9
  if torch.cuda.is_available():
10
  device = "cuda"
11
 
12
- # 2. Load Prompt Enhancer (The Brain)
13
  prompt_enhancer_id = "succinctly/text2image-prompt-generator"
14
  enhancer_tokenizer = GPT2Tokenizer.from_pretrained(prompt_enhancer_id)
15
  enhancer_model = GPT2LMHeadModel.from_pretrained(prompt_enhancer_id)
16
  enhancer_pipe = pipeline("text-generation", model=enhancer_model, tokenizer=enhancer_tokenizer, device=device)
17
 
18
- # 3. Load Image Generator (The Artist) - CPU Optimized
19
  image_model_id = "SimianLuo/LCM_Dreamshaper_v7"
20
  image_pipe = DiffusionPipeline.from_pretrained(image_model_id)
21
  image_pipe.to(device)
22
 
23
  def clean_and_format_prompt(generated_text, original_prompt):
24
- # List of "filler" words to remove
25
  bad_words = [
26
  "4k", "8k", "high definition", "high res", "high resolution",
27
  "hd", "ultra detailed", "masterpiece", "photorealistic",
28
  "best quality", "vray", "unreal engine", "octane render"
29
  ]
30
 
31
- # Clean the generated text
32
  cleaned = generated_text
 
 
 
 
33
  for word in bad_words:
34
- # Remove the word (case insensitive)
35
  cleaned = re.sub(r'\b' + word + r'\b', "", cleaned, flags=re.IGNORECASE)
36
 
37
- # Remove extra commas and spaces created by removal
38
  cleaned = re.sub(r',\s*,', ',', cleaned)
39
  cleaned = re.sub(r'\s+', ' ', cleaned).strip().strip(',')
40
 
41
- # Logic: If the enhancer didn't add much substance, use a template
42
- # This prevents "Apple" -> "Apple 4k 8k" (which becomes just "Apple" after cleaning)
43
- if len(cleaned) < len(original_prompt) + 10:
44
  cleaned = f"{original_prompt}, detailed, centered in frame"
45
 
46
  return cleaned
47
 
48
  def generate_workflow(prompt, width, height, steps):
49
- # Step 1: Analysis
50
- yield "🔍 Thinking (Improving your prompt)...", None, ""
51
 
52
- # Generate extension
53
  try:
54
- # We limit max_length to keep it concise
55
- enhanced_results = enhancer_pipe(prompt, max_length=60, num_return_sequences=1)
56
  raw_output = enhanced_results[0]['generated_text']
57
-
58
- # Apply our cleaning logic
59
  final_prompt = clean_and_format_prompt(raw_output, prompt)
60
- except Exception as e:
61
- # Fallback if enhancer fails
62
  final_prompt = f"{prompt}, detailed, centered in frame"
63
 
64
- # Step 2: Generation
65
- yield "🎨 Generating (Drawing the image)...", None, final_prompt
66
 
67
  image = image_pipe(
68
  prompt=final_prompt,
@@ -74,10 +65,11 @@ def generate_workflow(prompt, width, height, steps):
74
  output_type="pil"
75
  ).images[0]
76
 
77
- # Step 3: Finish
78
- yield "✅ Ready", image, final_prompt
 
 
79
 
80
- # UI Setup - Title and Theme belong here!
81
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Lab") as demo:
82
  gr.Markdown("# 🎨 AI Image Lab")
83
 
@@ -93,7 +85,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Lab") as demo:
93
  width_slider = gr.Slider(256, 768, 512, step=64, label="📏 Width")
94
  height_slider = gr.Slider(256, 768, 512, step=64, label="📐 Height")
95
 
96
- steps_slider = gr.Slider(1, 15, 4, step=1, label="🏃 Steps")
97
 
98
  generate_btn = gr.Button("🚀 Generate", variant="primary")
99
 
 
1
  import gradio as gr
2
  import torch
3
  import re
4
+ import time
5
  from diffusers import DiffusionPipeline
6
  from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
7
 
 
8
  device = "cpu"
9
  if torch.cuda.is_available():
10
  device = "cuda"
11
 
 
12
  prompt_enhancer_id = "succinctly/text2image-prompt-generator"
13
  enhancer_tokenizer = GPT2Tokenizer.from_pretrained(prompt_enhancer_id)
14
  enhancer_model = GPT2LMHeadModel.from_pretrained(prompt_enhancer_id)
15
  enhancer_pipe = pipeline("text-generation", model=enhancer_model, tokenizer=enhancer_tokenizer, device=device)
16
 
 
17
  image_model_id = "SimianLuo/LCM_Dreamshaper_v7"
18
  image_pipe = DiffusionPipeline.from_pretrained(image_model_id)
19
  image_pipe.to(device)
20
 
21
  def clean_and_format_prompt(generated_text, original_prompt):
 
22
  bad_words = [
23
  "4k", "8k", "high definition", "high res", "high resolution",
24
  "hd", "ultra detailed", "masterpiece", "photorealistic",
25
  "best quality", "vray", "unreal engine", "octane render"
26
  ]
27
 
 
28
  cleaned = generated_text
29
+ instruction_trigger = "Enhanced prompt:"
30
+ if instruction_trigger in cleaned:
31
+ cleaned = cleaned.split(instruction_trigger)[-1]
32
+
33
  for word in bad_words:
 
34
  cleaned = re.sub(r'\b' + word + r'\b', "", cleaned, flags=re.IGNORECASE)
35
 
 
36
  cleaned = re.sub(r',\s*,', ',', cleaned)
37
  cleaned = re.sub(r'\s+', ' ', cleaned).strip().strip(',')
38
 
39
+ if len(cleaned) < 5:
 
 
40
  cleaned = f"{original_prompt}, detailed, centered in frame"
41
 
42
  return cleaned
43
 
44
  def generate_workflow(prompt, width, height, steps):
45
+ start_time = time.time()
46
+ yield "🔍 Thinking (analysing AI)...", None, ""
47
 
 
48
  try:
49
+ instructional_prompt = f"Enhance the user prompt so it is suitable for an image generator, and focus on the object, not on the quality, resolution etc. User prompt: {prompt}. Enhanced prompt:"
50
+ enhanced_results = enhancer_pipe(instructional_prompt, max_new_tokens=40, num_return_sequences=1)
51
  raw_output = enhanced_results[0]['generated_text']
 
 
52
  final_prompt = clean_and_format_prompt(raw_output, prompt)
53
+ except:
 
54
  final_prompt = f"{prompt}, detailed, centered in frame"
55
 
56
+ yield "🎨 Generating (Image generator AI)...", None, final_prompt
 
57
 
58
  image = image_pipe(
59
  prompt=final_prompt,
 
65
  output_type="pil"
66
  ).images[0]
67
 
68
+ end_time = time.time()
69
+ duration = round(end_time - start_time, 2)
70
+
71
+ yield f"✅ Done in {duration}s", image, final_prompt
72
 
 
73
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Lab") as demo:
74
  gr.Markdown("# 🎨 AI Image Lab")
75
 
 
85
  width_slider = gr.Slider(256, 768, 512, step=64, label="📏 Width")
86
  height_slider = gr.Slider(256, 768, 512, step=64, label="📐 Height")
87
 
88
+ steps_slider = gr.Slider(4, 12, 5, step=1, label="🏃 Steps")
89
 
90
  generate_btn = gr.Button("🚀 Generate", variant="primary")
91