pixelsdesign commited on
Commit
62641a0
·
verified ·
1 Parent(s): 2ee0042

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -11
app.py CHANGED
@@ -2,28 +2,51 @@ import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
- # 🔹 Load PosterCraft model from Hugging Face Hub
6
- pipe = StableDiffusionPipeline.from_pretrained(
7
- "PosterCraft/PosterCraft-v1_RL", # <-- PosterCraft model
8
- torch_dtype=torch.float16
9
- )
 
 
 
 
 
 
 
 
 
10
 
11
- # Use GPU if available for faster generation
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  pipe.to(device)
14
 
15
- def generate_poster(prompt):
16
- # Generate the poster image
 
 
 
 
 
 
 
17
  image = pipe(prompt).images[0]
18
  return image
19
 
 
20
  # Gradio UI
 
21
  demo = gr.Interface(
22
  fn=generate_poster,
23
- inputs=gr.Textbox(label="Enter your poster prompt"),
 
 
 
24
  outputs=gr.Image(type="pil", label="Generated Poster"),
25
  title="AI Poster Generator",
26
- description="Type something like: 'modern event flyer with bold typography and sunset gradient background'"
27
  )
28
 
29
- demo.launch()
 
 
 
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
+ # ----------------------------------------
6
+ # Load PosterCraft model from Hugging Face
7
+ # ----------------------------------------
8
+ # If your Space uses a free CPU runtime, you can leave torch_dtype=None.
9
+ # On GPU, float16 is faster and uses less memory.
10
+ model_id = "PosterCraft/PosterCraft-v1_RL"
11
+
12
+ try:
13
+ pipe = StableDiffusionPipeline.from_pretrained(
14
+ model_id,
15
+ torch_dtype=torch.float16 if torch.cuda.is_available() else None
16
+ )
17
+ except Exception as e:
18
+ raise RuntimeError(f"Error loading model {model_id}: {e}")
19
 
20
+ # Select device
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  pipe.to(device)
23
 
24
+ # ----------------------------------------
25
+ # Generation function
26
+ # ----------------------------------------
27
+ def generate_poster(prompt: str):
28
+ """
29
+ Generate a poster/flyer based on the user's text prompt.
30
+ """
31
+ if not prompt.strip():
32
+ return None
33
  image = pipe(prompt).images[0]
34
  return image
35
 
36
+ # ----------------------------------------
37
  # Gradio UI
38
+ # ----------------------------------------
39
  demo = gr.Interface(
40
  fn=generate_poster,
41
+ inputs=gr.Textbox(
42
+ label="Enter your poster prompt",
43
+ placeholder="e.g. Modern music festival flyer with bold typography and neon lights"
44
+ ),
45
  outputs=gr.Image(type="pil", label="Generated Poster"),
46
  title="AI Poster Generator",
47
+ description="Enter a description and get an AI-designed poster or flyer."
48
  )
49
 
50
+ # Launch the app
51
+ if __name__ == "__main__":
52
+ demo.launch()