sethchitty committed on
Commit
ca6bb8b
·
verified ·
1 Parent(s): 1e35146

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ from diffusers import AutoPipelineForText2Image, DDIMScheduler
5
+ from transformers import CLIPVisionModelWithProjection
6
+ from diffusers.utils import load_image
7
+ from PIL import Image
8
+ import os
9
+ import json
10
+ import gc
11
+ import traceback
12
+
13
# Reference image URLs for each supported style name. The images are loaded
# per request and passed to the pipeline as the first IP-Adapter input.
STYLE_MAP = {
    "pixar": [
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img0.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img1.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img2.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img3.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img4.png",
    ],
}

# Use half precision (and the fp16 weight variant) only when a CUDA GPU is
# present; otherwise fall back to float32 on the CPU.
_has_cuda = torch.cuda.is_available()
torch_dtype = torch.float16 if _has_cuda else torch.float32
device = "cuda" if _has_cuda else "cpu"
print(f"🚀 Device: {device}, torch_dtype: {torch_dtype}")

# The image encoder is loaded explicitly and handed to the pipeline below.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch_dtype,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
    image_encoder=image_encoder,
    variant="fp16" if _has_cuda else None,
)

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# Two adapters are loaded in this order; the scale list below and the
# `ip_adapter_image` list used at generation time follow the same order.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])

if _has_cuda:
    # Model CPU offload manages device placement itself, so the pipeline is
    # deliberately NOT moved to the GPU first: doing so would put the whole
    # model in VRAM at start-up, which is what low-VRAM mode tries to avoid.
    pipeline.enable_model_cpu_offload()
else:
    # Offloading needs an accelerator (see the diffusers memory guide), so on
    # a CPU-only host just keep the pipeline on the CPU.
    pipeline.to(device)
pipeline.enable_vae_tiling()
53
def generate_storybook(data):
    """Generate one illustration per scene prompt.

    Args:
        data: Mapping with the keys ``"character_image_url"`` (passed to
            ``load_image`` and used as the second IP-Adapter image),
            ``"style"`` (a key of ``STYLE_MAP``) and ``"scenes"`` (an
            iterable of prompt strings).

    Returns:
        A list of PIL images, one for every scene that rendered
        successfully. Scenes that raise, or whose result is not a PIL
        image, are logged and skipped, so the list may be shorter than
        ``scenes``.

    Raises:
        KeyError: If one of the required keys is missing from ``data``.
        ValueError: If ``style`` has no reference images in ``STYLE_MAP``.
    """
    print("📥 Input JSON received:")
    print(json.dumps(data, indent=2))

    character_image_url = data["character_image_url"]
    style = data["style"]
    scenes = data["scenes"]

    # Fail fast on an unsupported style. Otherwise an empty reference list
    # would make every scene fail inside the loop below and the function
    # would quietly return an empty result.
    style_urls = STYLE_MAP.get(style)
    if not style_urls:
        raise ValueError(
            f"Unknown style {style!r}; available styles: {sorted(STYLE_MAP)}"
        )

    face_image = load_image(character_image_url)
    style_images = [load_image(url) for url in style_urls]

    images = []

    for i, prompt in enumerate(scenes):
        print(f"🎬 Generating scene {i+1}: {prompt}")
        try:
            # Release cached memory before each scene.
            torch.cuda.empty_cache()
            gc.collect()

            result = pipeline(
                prompt=prompt,
                # Style references first, then the character image.
                ip_adapter_image=[style_images, face_image],
                negative_prompt="blurry, bad anatomy, low quality",
                width=448,
                height=672,
                guidance_scale=5.0,
                num_inference_steps=15,
                # Seed depends only on the scene index, so reruns are repeatable.
                generator=torch.Generator(device).manual_seed(i + 42),
            )

            image = result.images[0] if hasattr(result, "images") else result
            print(f"🖼️ Image type: {type(image)}")

            if isinstance(image, Image.Image):
                images.append(image)
                print(f"✅ Scene {i+1} added to image list.")
            else:
                print(f"⚠️ Scene {i+1} is not a valid image object.")

        except Exception as e:
            # Best effort: one failed scene must not abort the remaining ones.
            print(f"❌ Exception during scene {i+1}: {e}")
            traceback.print_exc()

    print(f"📦 Returning {len(images)} image(s)")
    return images
98
+
99
def generate_storybook_from_textbox(json_input_text):
    """Gradio callback: parse the textbox JSON and generate the scenes.

    Args:
        json_input_text: Raw text from the input textbox; it must be a JSON
            document accepted by ``generate_storybook``.

    Returns:
        The list of images returned by ``generate_storybook``.

    Raises:
        gr.Error: If the text is not valid JSON or generation fails. The
            original exception is chained as the cause.
    """
    try:
        data = json.loads(json_input_text)
        return generate_storybook(data)
    except Exception as e:
        print(f"❌ JSON parse or generation error: {e}")
        traceback.print_exc()
        # This function feeds an image gallery, and a message string is not
        # an image it can show. Raising gr.Error lets Gradio report the
        # failure to the user instead.
        raise gr.Error(f"Error: {str(e)}") from e
107
+
108
# Build the input and output components first, then wire them into the
# Interface: a JSON textbox in, a single-column gallery of scenes out.
json_input_box = gr.Textbox(label="Input JSON", lines=20, placeholder="{...}")
scene_gallery = gr.Gallery(
    label="Generated Story Scenes",
    show_label=True,
    columns=1,
)

iface = gr.Interface(
    fn=generate_storybook_from_textbox,
    inputs=json_input_box,
    outputs=scene_gallery,
    title="AI Storybook Generator (Low VRAM Mode)",
    description="Optimized for lower VRAM GPUs. Paste JSON to generate consistent scenes.",
)

# Start the web app.
iface.launch()