CryptoCreeper committed on
Commit
ccb52fe
·
verified ·
1 Parent(s): 197d222

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -25
app.py CHANGED
@@ -4,45 +4,68 @@ import re
4
  from diffusers import DiffusionPipeline
5
  from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel
6
 
 
7
  device = "cpu"
8
  if torch.cuda.is_available():
9
  device = "cuda"
10
 
 
11
  prompt_enhancer_id = "succinctly/text2image-prompt-generator"
12
  enhancer_tokenizer = GPT2Tokenizer.from_pretrained(prompt_enhancer_id)
13
  enhancer_model = GPT2LMHeadModel.from_pretrained(prompt_enhancer_id)
14
  enhancer_pipe = pipeline("text-generation", model=enhancer_model, tokenizer=enhancer_tokenizer, device=device)
15
 
 
16
  image_model_id = "SimianLuo/LCM_Dreamshaper_v7"
17
  image_pipe = DiffusionPipeline.from_pretrained(image_model_id)
18
  image_pipe.to(device)
19
 
20
- def clean_prompt(text, original_user_input):
21
- forbidden_words = ["4k", "8k", "high res", "high resolution", "hd", "ultra detailed", "masterpiece", "vray", "render"]
22
- for word in forbidden_words:
23
- text = re.sub(r'\b' + word + r'\b', "", text, flags=re.IGNORECASE)
 
 
 
24
 
25
- text = re.sub(r'\s+', ' ', text).strip()
 
 
 
 
26
 
27
- if len(original_user_input.split()) < 4:
28
- text = f"{original_user_input}, realistic look, centered in the image"
29
- elif "centered" not in text.lower():
30
- text += ", centered in the image"
31
-
32
- return text
33
-
34
- def generate_workflow(prompt, width, height, steps):
35
- yield "🔍 Thinking (analysing AI)...", None, ""
36
 
37
- enhanced_results = enhancer_pipe(prompt, max_length=60, num_return_sequences=1)
38
- raw_enhanced = enhanced_results[0]['generated_text']
 
 
39
 
40
- refined_prompt = clean_prompt(raw_enhanced, prompt)
 
 
 
 
41
 
42
- yield "🎨 Generating (Image generator AI)...", None, refined_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  image = image_pipe(
45
- prompt=refined_prompt,
46
  width=int(width),
47
  height=int(height),
48
  num_inference_steps=int(steps),
@@ -51,16 +74,18 @@ def generate_workflow(prompt, width, height, steps):
51
  output_type="pil"
52
  ).images[0]
53
 
54
- yield "✅ Ready", image, refined_prompt
 
55
 
56
- with gr.Blocks() as demo:
 
57
  gr.Markdown("# 🎨 AI Image Lab")
58
 
59
  with gr.Row():
60
  with gr.Column(scale=1):
61
  prompt_input = gr.Textbox(
62
  label="💡 Your Idea",
63
- placeholder="e.g., Apple fruit",
64
  lines=3
65
  )
66
 
@@ -80,8 +105,7 @@ with gr.Blocks() as demo:
80
  generate_btn.click(
81
  fn=generate_workflow,
82
  inputs=[prompt_input, width_slider, height_slider, steps_slider],
83
- outputs=[status_bar, image_output, refined_prompt_display],
84
- api_name="predict"
85
  )
86
 
87
- demo.launch(theme=gr.themes.Soft(), title="AI Image Lab")
 
4
from diffusers import DiffusionPipeline
from transformers import pipeline, GPT2Tokenizer, GPT2LMHeadModel

# 1. Setup Device — use the GPU when CUDA is available, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# 2. Load Prompt Enhancer (The Brain)
# GPT-2 text-generation pipeline that expands a short user idea into a richer
# text-to-image prompt; its raw output is post-processed by
# clean_and_format_prompt() before being fed to the image model.
prompt_enhancer_id = "succinctly/text2image-prompt-generator"
enhancer_tokenizer = GPT2Tokenizer.from_pretrained(prompt_enhancer_id)
enhancer_model = GPT2LMHeadModel.from_pretrained(prompt_enhancer_id)
enhancer_pipe = pipeline("text-generation", model=enhancer_model, tokenizer=enhancer_tokenizer, device=device)

# 3. Load Image Generator (The Artist) - CPU Optimized
# NOTE(review): the LCM checkpoint is presumably chosen because Latent
# Consistency Models work with very few inference steps (tolerable on CPU) —
# confirm the steps slider default matches the model's recommended range.
image_model_id = "SimianLuo/LCM_Dreamshaper_v7"
image_pipe = DiffusionPipeline.from_pretrained(image_model_id)
image_pipe.to(device)
23
def clean_and_format_prompt(generated_text, original_prompt):
    """Strip quality-spam filler words from an enhancer-generated prompt.

    Args:
        generated_text: Raw text produced by the prompt-enhancer model.
        original_prompt: The user's original idea. Used both as the fallback
            template and to judge whether the enhancer added real substance.

    Returns:
        The cleaned prompt string, or a simple template built from
        ``original_prompt`` when the enhancer contributed little beyond
        filler (e.g. "Apple" -> "Apple 4k 8k" -> just "Apple" after cleaning).
    """
    # "Filler" quality tags that add noise rather than visual content.
    bad_words = [
        "4k", "8k", "high definition", "high res", "high resolution",
        "hd", "ultra detailed", "masterpiece", "photorealistic",
        "best quality", "vray", "unreal engine", "octane render",
    ]

    cleaned = generated_text
    for word in bad_words:
        # re.escape keeps each entry a literal match even if the list ever
        # gains regex metacharacters; matching is case-insensitive and
        # word-bounded so e.g. "high res" won't eat "high resolution".
        cleaned = re.sub(r'\b' + re.escape(word) + r'\b', "", cleaned, flags=re.IGNORECASE)

    # Collapse the ", ," and whitespace runs left behind by the removals,
    # then trim stray edge whitespace/commas.
    cleaned = re.sub(r',\s*,', ',', cleaned)
    cleaned = re.sub(r'\s+', ' ', cleaned).strip().strip(',')

    # If the enhancer didn't add much substance beyond the original idea,
    # fall back to a fixed template instead of a near-empty prompt.
    if len(cleaned) < len(original_prompt) + 10:
        cleaned = f"{original_prompt}, detailed, centered in frame"

    return cleaned
47
+
48
+ def generate_workflow(prompt, width, height, steps):
49
+ # Step 1: Analysis
50
+ yield "🔍 Thinking (Improving your prompt)...", None, ""
51
 
52
+ # Generate extension
53
+ try:
54
+ # We limit max_length to keep it concise
55
+ enhanced_results = enhancer_pipe(prompt, max_length=60, num_return_sequences=1)
56
+ raw_output = enhanced_results[0]['generated_text']
57
+
58
+ # Apply our cleaning logic
59
+ final_prompt = clean_and_format_prompt(raw_output, prompt)
60
+ except Exception as e:
61
+ # Fallback if enhancer fails
62
+ final_prompt = f"{prompt}, detailed, centered in frame"
63
+
64
+ # Step 2: Generation
65
+ yield "🎨 Generating (Drawing the image)...", None, final_prompt
66
 
67
  image = image_pipe(
68
+ prompt=final_prompt,
69
  width=int(width),
70
  height=int(height),
71
  num_inference_steps=int(steps),
 
74
  output_type="pil"
75
  ).images[0]
76
 
77
+ # Step 3: Finish
78
+ yield "✅ Ready", image, final_prompt
79
 
80
+ # UI Setup - Title and Theme belong here!
81
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Lab") as demo:
82
  gr.Markdown("# 🎨 AI Image Lab")
83
 
84
  with gr.Row():
85
  with gr.Column(scale=1):
86
  prompt_input = gr.Textbox(
87
  label="💡 Your Idea",
88
+ placeholder="e.g., A cute dragon",
89
  lines=3
90
  )
91
 
 
105
  generate_btn.click(
106
  fn=generate_workflow,
107
  inputs=[prompt_input, width_slider, height_slider, steps_slider],
108
+ outputs=[status_bar, image_output, refined_prompt_display]
 
109
  )
110
 
111
+ demo.launch()