gaur3009 commited on
Commit
06df74c
·
verified ·
1 Parent(s): 5dea66d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -59
app.py CHANGED
@@ -1,19 +1,16 @@
1
  import gradio as gr
2
  from gradio_client import Client, handle_file
3
- from PIL import Image, ImageOps, ImageDraw, ImageFont
4
  import io, os, base64, requests, random
5
 
6
- # HF Client for Z-Image model
7
  client = Client("Tongyi-MAI/Z-Image-Turbo")
8
 
9
- # ---------------------------------------------
10
- # Core function to call HF Space
11
- # ---------------------------------------------
12
- def call_generate(prompt, resolution="1024x1536 ( 2:3 )", seed=None, steps=8, gallery_urls=None, random_seed=False):
13
  gallery_images = []
14
- if gallery_urls:
15
- for url in gallery_urls:
16
- gallery_images.append({"image": handle_file(url), "caption": None})
17
 
18
  payload = dict(
19
  prompt=prompt,
@@ -25,104 +22,93 @@ def call_generate(prompt, resolution="1024x1536 ( 2:3 )", seed=None, steps=8, ga
25
  gallery_images=gallery_images,
26
  api_name="/generate",
27
  )
28
-
29
  return client.predict(**payload)
30
 
31
- # ---------------------------------------------
32
- # Generate posters & overlay tagline + logo
33
- # ---------------------------------------------
34
  def generate_posters(product_name, tagline, product_img, logo_img, variants):
35
 
36
- # Save uploaded files locally so we can pass URLs to the HF API
37
- product_path = "product_temp.png"
38
- logo_path = "logo_temp.png"
 
 
 
39
 
40
  Image.open(product_img).save(product_path)
41
  Image.open(logo_img).save(logo_path)
42
 
43
- product_url = "file://" + os.path.abspath(product_path)
44
- logo_url = "file://" + os.path.abspath(logo_path)
45
-
46
  base_prompt = (
47
- f"{product_name} on a sandy beach at sunset composition: product foreground centered, "
48
- f"low-angle shot, golden hour rim light, cinematic mood, glossy finish, photorealistic high detail."
49
  )
50
 
51
  output_images = []
52
 
53
  for i in range(variants):
54
  seed = random.randint(1, 10**6)
55
- prompt = base_prompt + f" Tagline: \"{tagline}\" Variation {i}."
56
 
57
  raw = call_generate(
58
- prompt,
59
  resolution="1024x1536 ( 2:3 )",
60
  seed=seed,
61
  steps=12,
62
- gallery_urls=[product_url, logo_url],
63
  )
64
 
65
- # Try to decode base64 or URL response
66
  img_bytes = None
67
  if isinstance(raw, dict):
68
  for v in raw.values():
69
  if isinstance(v, str) and v.startswith("data:"):
70
- img_bytes = base64.b64decode(v.split(",")[1])
71
  break
72
  if isinstance(v, list):
73
  for item in v:
74
  if isinstance(item, str) and item.startswith("data:"):
75
- img_bytes = base64.b64decode(item.split(",")[1])
76
  break
77
 
78
- if img_bytes is None and isinstance(raw, str) and raw.startswith("data:"):
79
- img_bytes = base64.b64decode(raw.split(",", 1)[1])
80
-
81
- if img_bytes is None and isinstance(raw, str) and raw.startswith("http"):
82
- img_bytes = requests.get(raw).content
83
-
84
  if img_bytes is None:
85
- return [f"Error decoding image: {raw}"]
86
 
87
  img = Image.open(io.BytesIO(img_bytes)).convert("RGBA")
88
 
89
- # ---- Overlay Logo ---
90
  try:
91
- logo = Image.open(logo_img).convert("RGBA")
92
- logo.thumbnail((int(img.width * 0.18), int(img.height * 0.08)), Image.LANCZOS)
93
- logo_pos = (img.width - logo.width - 40, img.height - logo.height - 40)
94
- img.alpha_composite(logo, logo_pos)
95
  except:
96
  pass
97
 
98
- # ---- Overlay Tagline ---
99
  draw = ImageDraw.Draw(img)
100
  try:
101
- font = ImageFont.truetype("Arial.ttf", size=int(img.width*0.06))
102
  except:
103
  font = ImageFont.load_default()
104
 
105
- text_w, text_h = draw.textsize(tagline, font=font)
106
- x, y = 60, img.height - text_h - 60
107
- draw.rectangle([x-15, y-15, x+text_w+15, y+text_h+15], fill=(0,0,0,160))
108
- draw.text((x,y), tagline, font=font, fill=(255,255,255,255))
109
 
110
- final_path = f"poster_{seed}_{i}.png"
111
- img.convert("RGB").save(final_path)
112
- output_images.append(final_path)
113
 
114
  return output_images
115
 
116
- # ---------------------------------------------
 
117
  # GRADIO UI
118
- # ---------------------------------------------
119
  with gr.Blocks(title="Rookus Poster Generator") as demo:
120
 
121
- gr.Markdown("## 🖼️ Rookus AI Poster Generator\nCreate cinematic advertising posters automatically.")
122
 
123
- with gr.Row():
124
- product_name = gr.Textbox(label="Product Name", placeholder="Sneakers")
125
- tagline = gr.Textbox(label="Tagline", placeholder="Purchase these beautiful sneakers now")
126
 
127
  with gr.Row():
128
  product_img = gr.Image(type="filepath", label="Product Image")
@@ -130,14 +116,13 @@ with gr.Blocks(title="Rookus Poster Generator") as demo:
130
 
131
  variants = gr.Slider(1, 6, value=3, step=1, label="Number of Variants")
132
 
133
- generate_btn = gr.Button("Generate Posters 🎨")
134
-
135
- gallery = gr.Gallery(label="Generated Posters", height="auto")
136
 
137
  generate_btn.click(
138
- fn=generate_posters,
139
  inputs=[product_name, tagline, product_img, logo_img, variants],
140
- outputs=[gallery]
141
  )
142
 
143
  demo.launch()
 
1
  import gradio as gr
2
  from gradio_client import Client, handle_file
3
+ from PIL import Image, ImageDraw, ImageFont
4
  import io, os, base64, requests, random
5
 
 
6
  client = Client("Tongyi-MAI/Z-Image-Turbo")
7
 
8
+ def call_generate(prompt, resolution="1024x1536 ( 2:3 )", seed=None, steps=8, gallery_files=None, random_seed=False):
9
+
 
 
10
  gallery_images = []
11
+ if gallery_files:
12
+ for path in gallery_files:
13
+ gallery_images.append({"image": handle_file(path), "caption": None})
14
 
15
  payload = dict(
16
  prompt=prompt,
 
22
  gallery_images=gallery_images,
23
  api_name="/generate",
24
  )
 
25
  return client.predict(**payload)
26
 
 
 
 
27
  def generate_posters(product_name, tagline, product_img, logo_img, variants):
28
 
29
+ if product_img is None or logo_img is None:
30
+ return ["❌ Please upload both product image and logo."]
31
+
32
+ # Save uploaded files
33
+ product_path = "product.png"
34
+ logo_path = "logo.png"
35
 
36
  Image.open(product_img).save(product_path)
37
  Image.open(logo_img).save(logo_path)
38
 
 
 
 
39
  base_prompt = (
40
+ f"{product_name} on a sandy beach at sunset, cinematic composition, golden rim light, "
41
+ f"glossy highlights, photorealistic 4K detail. Clean whitespace for text."
42
  )
43
 
44
  output_images = []
45
 
46
  for i in range(variants):
47
  seed = random.randint(1, 10**6)
48
+ full_prompt = base_prompt + f" Tagline: \"{tagline}\"."
49
 
50
  raw = call_generate(
51
+ full_prompt,
52
  resolution="1024x1536 ( 2:3 )",
53
  seed=seed,
54
  steps=12,
55
+ gallery_files=[product_path, logo_path]
56
  )
57
 
58
+ # Parse response
59
  img_bytes = None
60
  if isinstance(raw, dict):
61
  for v in raw.values():
62
  if isinstance(v, str) and v.startswith("data:"):
63
+ img_bytes = base64.b64decode(v.split(",", 1)[1])
64
  break
65
  if isinstance(v, list):
66
  for item in v:
67
  if isinstance(item, str) and item.startswith("data:"):
68
+ img_bytes = base64.b64decode(item.split(",", 1)[1])
69
  break
70
 
 
 
 
 
 
 
71
  if img_bytes is None:
72
+ return [" Could not decode model output"]
73
 
74
  img = Image.open(io.BytesIO(img_bytes)).convert("RGBA")
75
 
76
+ # Overlay Logo
77
  try:
78
+ logo = Image.open(logo_path).convert("RGBA")
79
+ logo.thumbnail((int(img.width * 0.18), int(img.height * 0.08)))
80
+ img.alpha_composite(logo, (img.width - logo.width - 30, img.height - logo.height - 30))
 
81
  except:
82
  pass
83
 
84
+ # Overlay Tagline
85
  draw = ImageDraw.Draw(img)
86
  try:
87
+ font = ImageFont.truetype("Arial.ttf", size=int(img.width * 0.06))
88
  except:
89
  font = ImageFont.load_default()
90
 
91
+ tw, th = draw.textsize(tagline, font=font)
92
+ x, y = 60, img.height - th - 60
93
+ draw.rectangle([x-12, y-12, x+tw+12, y+th+12], fill=(0,0,0,160))
94
+ draw.text((x, y), tagline, fill="white", font=font)
95
 
96
+ out_path = f"poster_{seed}_{i}.png"
97
+ img.convert("RGB").save(out_path)
98
+ output_images.append(out_path)
99
 
100
  return output_images
101
 
102
+
103
+ # ---------------------------
104
  # GRADIO UI
105
+ # ---------------------------
106
  with gr.Blocks(title="Rookus Poster Generator") as demo:
107
 
108
+ gr.Markdown("## 🎨 Rookus AI Poster Generator")
109
 
110
+ product_name = gr.Textbox(label="Product Name")
111
+ tagline = gr.Textbox(label="Tagline")
 
112
 
113
  with gr.Row():
114
  product_img = gr.Image(type="filepath", label="Product Image")
 
116
 
117
  variants = gr.Slider(1, 6, value=3, step=1, label="Number of Variants")
118
 
119
+ generate_btn = gr.Button("Generate Posters")
120
+ gallery = gr.Gallery(label="Output Posters")
 
121
 
122
  generate_btn.click(
123
+ generate_posters,
124
  inputs=[product_name, tagline, product_img, logo_img, variants],
125
+ outputs=gallery
126
  )
127
 
128
  demo.launch()