sethchitty commited on
Commit
f870136
·
verified ·
1 Parent(s): 1878ae8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ from diffusers import AutoPipelineForText2Image, DDIMScheduler
5
+ from transformers import CLIPVisionModelWithProjection
6
+ from diffusers.utils import load_image
7
+ from PIL import Image
8
+ import os
9
+ import json
10
+ import gc
11
+ import traceback
12
+
13
+ STYLE_MAP = {
14
+ "pixar": [
15
+ "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img0.png",
16
+ "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img1.png",
17
+ "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img2.png",
18
+ "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img3.png",
19
+ "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img4.png"
20
+ ]
21
+ }
22
+
23
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ print(f"🚀 Device: {device}, torch_dtype: {torch_dtype}")
26
+
27
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
28
+ "h94/IP-Adapter",
29
+ subfolder="models/image_encoder",
30
+ torch_dtype=torch_dtype,
31
+ )
32
+
33
+ pipeline = AutoPipelineForText2Image.from_pretrained(
34
+ "stabilityai/stable-diffusion-xl-base-1.0",
35
+ torch_dtype=torch_dtype,
36
+ image_encoder=image_encoder,
37
+ variant="fp16" if torch.cuda.is_available() else None
38
+ ).to(device)
39
+
40
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
41
+ pipeline.load_ip_adapter(
42
+ "h94/IP-Adapter",
43
+ subfolder="sdxl_models",
44
+ weight_name=[
45
+ "ip-adapter-plus_sdxl_vit-h.safetensors",
46
+ "ip-adapter-plus-face_sdxl_vit-h.safetensors"
47
+ ]
48
+ )
49
+ pipeline.set_ip_adapter_scale([0.7, 0.3])
50
+ pipeline.enable_model_cpu_offload()
51
+
52
+ def generate_storybook(data):
53
+ print("📥 Input JSON received:")
54
+ print(json.dumps(data, indent=2))
55
+
56
+ character_image_url = data["character_image_url"]
57
+ style = data["style"]
58
+ scenes = data["scenes"]
59
+
60
+ face_image = load_image(character_image_url)
61
+ style_images = [load_image(url) for url in STYLE_MAP.get(style, [])]
62
+
63
+ images = []
64
+
65
+ for i, prompt in enumerate(scenes):
66
+ print(f"🎬 Generating scene {i+1}: {prompt}")
67
+ try:
68
+ torch.cuda.empty_cache()
69
+ gc.collect()
70
+
71
+ result = pipeline(
72
+ prompt=prompt,
73
+ ip_adapter_image=[style_images, face_image],
74
+ negative_prompt="blurry, bad anatomy, low quality",
75
+ width=512,
76
+ height=768,
77
+ guidance_scale=7.5,
78
+ num_inference_steps=20,
79
+ generator=torch.Generator(device).manual_seed(i + 42)
80
+ )
81
+
82
+ # Check whether result is a dict or image
83
+ if hasattr(result, "images"):
84
+ image = result.images[0]
85
+ else:
86
+ image = result
87
+
88
+ print(f"🖼️ Image type: {type(image)}")
89
+
90
+ if isinstance(image, Image.Image):
91
+ images.append(image)
92
+ print(f"✅ Scene {i+1} added to image list.")
93
+ else:
94
+ print(f"⚠️ Scene {i+1} is not a valid image object.")
95
+
96
+ except Exception as e:
97
+ print(f"❌ Exception during scene {i+1}: {e}")
98
+ traceback.print_exc()
99
+
100
+ print(f"📦 Returning {len(images)} image(s)")
101
+ return images
102
+
103
+ def generate_storybook_from_textbox(json_input_text):
104
+ try:
105
+ data = json.loads(json_input_text)
106
+ return generate_storybook(data)
107
+ except Exception as e:
108
+ print(f"❌ JSON parse or generation error: {e}")
109
+ traceback.print_exc()
110
+ return [f"Error: {str(e)}"]
111
+
112
+ iface = gr.Interface(
113
+ fn=generate_storybook_from_textbox,
114
+ inputs=gr.Textbox(label="Input JSON", lines=20, placeholder="{...}"),
115
+ outputs=gr.Gallery(label="Generated Story Scenes", show_label=True, columns=1),
116
+ title="AI Storybook Generator (Render Fix)",
117
+ description="Paste JSON to generate story scenes with fixed image rendering."
118
+ )
119
+
120
+ iface.launch()