yukee1992 committed on
Commit
6bf397f
·
verified ·
1 Parent(s): b05c170

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -130
app.py CHANGED
@@ -18,6 +18,7 @@ import random
18
  import gc
19
  import psutil
20
  import threading
 
21
 
22
  # External OCI API URL
23
  OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"
@@ -51,7 +52,7 @@ class StorybookRequest(BaseModel):
51
  model_choice: str = "sdxl"
52
  style: str = "childrens_book"
53
 
54
- # MODEL SELECTION - SDXL handles longer prompts better
55
  MODEL_CHOICES = {
56
  "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
57
  "sdxl-turbo": "stabilityai/sdxl-turbo",
@@ -59,10 +60,11 @@ MODEL_CHOICES = {
59
  "realistic-vision": "SG161222/Realistic_Vision_V5.1",
60
  }
61
 
62
- # GLOBAL MODEL CACHE
63
  model_cache = {}
64
  current_model_name = None
65
  current_pipe = None
 
66
 
67
  # Character consistency tracking
68
  character_descriptions = {}
@@ -80,151 +82,139 @@ def cleanup_memory():
80
  gc.collect()
81
  if torch.cuda.is_available():
82
  torch.cuda.empty_cache()
83
- print("🧹 Memory cleaned up")
84
 
85
  def load_model(model_name="sdxl"):
 
86
  global model_cache, current_model_name, current_pipe
87
 
88
- if model_name in model_cache:
89
- current_pipe = model_cache[model_name]
90
- current_model_name = model_name
91
- return current_pipe
92
-
93
- print(f"πŸ”„ Loading model: {model_name}")
94
- try:
95
- if model_name in ["sdxl", "sdxl-turbo"]:
96
- model_id = MODEL_CHOICES[model_name]
97
- pipe = StableDiffusionXLPipeline.from_pretrained(
98
- model_id,
99
- torch_dtype=torch.float32,
100
- use_safetensors=True,
101
- safety_checker=None,
102
- requires_safety_checker=False
103
- )
104
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
105
- else:
106
- model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
107
- pipe = StableDiffusionPipeline.from_pretrained(
108
- model_id,
109
- torch_dtype=torch.float32,
110
- safety_checker=None,
111
- requires_safety_checker=False
112
- )
113
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
114
 
115
- pipe = pipe.to("cpu")
116
- model_cache[model_name] = pipe
117
- current_pipe = pipe
118
- current_model_name = model_name
119
-
120
- print(f"βœ… Model loaded: {model_name}")
121
- return pipe
122
-
123
- except Exception as e:
124
- print(f"❌ Model loading failed: {e}")
125
- pipe = StableDiffusionPipeline.from_pretrained(
126
- "runwayml/stable-diffusion-v1-5",
127
- torch_dtype=torch.float32
128
- ).to("cpu")
129
- model_cache[model_name] = pipe
130
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # Initialize default model
133
  print("πŸš€ Initializing Storybook Generator...")
134
  current_pipe = load_model("sdxl")
135
  print("βœ… Model loaded and ready!")
136
 
137
- # TRUE UNLIMITED PROMPT SOLUTION
138
- def create_compressed_prompt(scene_visual, characters, style="childrens_book", page_number=1):
139
- """
140
- Create a compressed but comprehensive prompt that fits within token limits
141
- while preserving ALL important information
142
- """
143
- # Extract ONLY the most critical character features
144
- character_features = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  if characters:
146
  for char in characters:
147
- if hasattr(char, 'description'):
 
148
  desc = char.description
149
  elif isinstance(char, dict):
 
150
  desc = char.get('description', '')
151
  else:
152
- desc = str(char)
153
 
154
- # Extract key features: age, appearance, clothing
155
  import re
156
- # Get age if mentioned
157
- age_match = re.search(r'(\d+)[\- ]?year[\- ]?old', desc, re.IGNORECASE)
158
- age = f"{age_match.group(1)} year old" if age_match else ""
159
-
160
  # Get species/type
161
- species_match = re.search(r'(rabbit|hedgehog|bird|dog|cat|fox|bear|dragon|unicorn|human|girl|boy)', desc, re.IGNORECASE)
162
  species = species_match.group(1) if species_match else "character"
163
 
164
- # Get color/main features
165
- color_match = re.search(r'(blonde|brown|black|white|blue|red|green|yellow|golden|silver)', desc, re.IGNORECASE)
166
  color = color_match.group(1) if color_match else ""
167
 
168
- # Get key accessories
169
- accessories = []
170
- if 'glasses' in desc.lower(): accessories.append('glasses')
171
- if 'dress' in desc.lower(): accessories.append('dress')
172
- if 'hat' in desc.lower(): accessories.append('hat')
173
- if 'satchel' in desc.lower(): accessories.append('satchel')
174
-
175
- # Build compressed description
176
- compressed_desc = f"{age} {color} {species}".strip()
177
- if accessories:
178
- compressed_desc += f" with {', '.join(accessories)}"
179
-
180
- character_features.append(compressed_desc)
181
-
182
- # Build scene context
183
- continuity_context = f"scene {page_number}" if page_number > 1 else ""
184
-
185
- # Style templates (compressed)
186
- style_presets = {
187
- "childrens_book": "children's book illustration, watercolor, whimsical",
188
- "realistic": "photorealistic, professional photography",
189
- "fantasy": "fantasy art, digital painting, magical",
190
- "anime": "anime style, clean lines, vibrant colors"
191
- }
192
 
193
- style_prompt = style_presets.get(style, style_presets["childrens_book"])
 
 
194
 
195
- # Build the final compressed prompt
196
- compressed_prompt = f"{continuity_context} {scene_visual}"
197
 
198
- if character_features:
199
- compressed_prompt += f". Characters: {', '.join(character_features)}"
200
-
201
- compressed_prompt += f". Style: {style_prompt}. masterpiece, best quality, 4K"
202
-
203
- # Ensure it's within reasonable length
204
- words = compressed_prompt.split()
205
  if len(words) > 60:
206
- compressed_prompt = ' '.join(words[:60]) + '...'
207
 
208
- return compressed_prompt
209
 
210
  def enhance_prompt(scene_visual, characters, style="childrens_book", page_number=1):
211
- """
212
- Create optimized prompt that preserves essence while fitting token limits
213
- """
214
- # Use compressed prompt for the actual generation
215
- main_prompt = create_compressed_prompt(scene_visual, characters, style, page_number)
216
 
217
- print(f"πŸ“ Compressed prompt: {main_prompt}")
218
  print(f"πŸ“ Length: {len(main_prompt.split())} words")
219
 
220
  # Negative prompt
221
  negative_prompt = (
222
- "blurry, low quality, ugly, deformed, poorly drawn, bad anatomy, "
223
- "wrong anatomy, extra limb, missing limb, floating limbs, "
224
- "disconnected limbs, mutation, mutated, disgusting, bad art, "
225
- "beginner, amateur, distorted, watermark, signature, text, username, "
226
- "multiple people, crowd, group, different characters, inconsistent features, "
227
- "changed appearance, different face, altered features, low resolution"
228
  )
229
 
230
  return main_prompt, negative_prompt
@@ -274,18 +264,22 @@ def get_character_seed(story_title, character_name, page_number):
274
  return character_seeds[story_title][seed_key]
275
 
276
  def generate_storybook_page(scene_visual, story_title, sequence_number, scene_text, characters, model_choice="sdxl", style="childrens_book"):
 
277
  global current_pipe, current_model_name
278
 
279
  try:
 
280
  if model_choice != current_model_name:
 
281
  current_pipe = load_model(model_choice)
 
 
282
 
283
  enhanced_prompt, negative_prompt = enhance_prompt(
284
  scene_visual, characters, style, sequence_number
285
  )
286
 
287
  print(f"πŸ“– Generating page {sequence_number}")
288
- print(f"πŸ“ Using prompt: {enhanced_prompt}")
289
 
290
  if characters:
291
  char_names = []
@@ -308,17 +302,23 @@ def generate_storybook_page(scene_visual, story_title, sequence_number, scene_te
308
  scene_seed = hash(f"{story_title}_{sequence_number}") % 1000000
309
  generator.manual_seed(scene_seed)
310
 
311
- # Generate image
 
 
 
312
  image = current_pipe(
313
  prompt=enhanced_prompt,
314
  negative_prompt=negative_prompt,
315
- num_inference_steps=35,
316
- guidance_scale=7.5,
317
- width=768,
318
- height=768,
319
  generator=generator
320
  ).images[0]
321
 
 
 
 
322
  save_status = save_complete_storybook_page(image, story_title, sequence_number, scene_text)
323
  return image, save_status
324
 
@@ -326,48 +326,52 @@ def generate_storybook_page(scene_visual, story_title, sequence_number, scene_te
326
  return None, f"❌ Generation failed: {str(e)}"
327
 
328
  def batch_generate_complete_storybook(story_title, scenes_data, characters, model_choice="sdxl", style="childrens_book"):
329
- global character_descriptions, current_pipe
 
330
 
331
  results = []
332
  status_messages = []
333
 
334
- print(f"πŸ“š Starting batch generation: {story_title}")
335
  print(f"πŸ“– Pages: {len(scenes_data)}")
336
  print(f"πŸ‘€ Characters: {len(characters)}")
337
 
338
- if characters:
339
- character_descriptions[story_title] = characters
340
-
341
  current_pipe = load_model(model_choice)
342
- start_time = time.time()
343
 
344
  for i, scene_data in enumerate(scenes_data, 1):
345
  try:
346
- if i % 2 == 0:
347
- cleanup_memory()
348
-
349
  scene_visual = scene_data.get('visual', '')
350
  scene_text = scene_data.get('text', '')
351
 
352
  print(f"πŸ”„ Generating page {i}/{len(scenes_data)}...")
 
 
353
  image, status = generate_storybook_page(
354
  scene_visual, story_title, i, scene_text, characters, model_choice, style
355
  )
356
 
 
 
 
357
  if image:
358
  results.append((f"Page {i}", image, scene_text))
359
  status_messages.append(f"Page {i}: {status}")
360
 
361
- if i < len(scenes_data):
362
- time.sleep(2)
 
363
 
364
  except Exception as e:
365
  error_msg = f"❌ Failed page {i}: {str(e)}"
366
  print(error_msg)
367
  status_messages.append(error_msg)
368
 
369
- total_time = time.time() - start_time
370
  print(f"βœ… Batch completed in {total_time:.2f} seconds")
 
371
 
372
  return results, "\n".join(status_messages)
373
 
@@ -380,7 +384,15 @@ async def api_generate_storybook(request: StorybookRequest):
380
 
381
  start_time = time.time()
382
  scenes_data = [{"visual": scene.visual, "text": scene.text} for scene in request.scenes]
383
- characters_dict = [char.dict() for char in request.characters]
 
 
 
 
 
 
 
 
384
 
385
  results, status = batch_generate_complete_storybook(
386
  request.story_title,
 
18
  import gc
19
  import psutil
20
  import threading
21
+ from functools import lru_cache
22
 
23
  # External OCI API URL
24
  OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"
 
52
  model_choice: str = "sdxl"
53
  style: str = "childrens_book"
54
 
55
+ # MODEL SELECTION
56
  MODEL_CHOICES = {
57
  "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
58
  "sdxl-turbo": "stabilityai/sdxl-turbo",
 
60
  "realistic-vision": "SG161222/Realistic_Vision_V5.1",
61
  }
62
 
63
+ # GLOBAL MODEL CACHE with proper locking
64
  model_cache = {}
65
  current_model_name = None
66
  current_pipe = None
67
+ model_lock = threading.Lock()
68
 
69
  # Character consistency tracking
70
  character_descriptions = {}
 
82
  gc.collect()
83
  if torch.cuda.is_available():
84
  torch.cuda.empty_cache()
 
85
 
86
  def load_model(model_name="sdxl"):
87
+ """Thread-safe model loading with proper caching"""
88
  global model_cache, current_model_name, current_pipe
89
 
90
+ with model_lock:
91
+ if model_name in model_cache:
92
+ print(f"βœ… Using cached model: {model_name}")
93
+ current_pipe = model_cache[model_name]
94
+ current_model_name = model_name
95
+ return current_pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ print(f"πŸ”„ Loading model: {model_name}")
98
+ try:
99
+ if model_name in ["sdxl", "sdxl-turbo"]:
100
+ model_id = MODEL_CHOICES[model_name]
101
+ pipe = StableDiffusionXLPipeline.from_pretrained(
102
+ model_id,
103
+ torch_dtype=torch.float32,
104
+ use_safetensors=True,
105
+ safety_checker=None,
106
+ requires_safety_checker=False
107
+ )
108
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
109
+ else:
110
+ model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
111
+ pipe = StableDiffusionPipeline.from_pretrained(
112
+ model_id,
113
+ torch_dtype=torch.float32,
114
+ safety_checker=None,
115
+ requires_safety_checker=False
116
+ )
117
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
118
+
119
+ pipe = pipe.to("cpu")
120
+ model_cache[model_name] = pipe
121
+ current_pipe = pipe
122
+ current_model_name = model_name
123
+
124
+ print(f"βœ… Model loaded and cached: {model_name}")
125
+ return pipe
126
+
127
+ except Exception as e:
128
+ print(f"❌ Model loading failed: {e}")
129
+ pipe = StableDiffusionPipeline.from_pretrained(
130
+ "runwayml/stable-diffusion-v1-5",
131
+ torch_dtype=torch.float32
132
+ ).to("cpu")
133
+ model_cache[model_name] = pipe
134
+ return pipe
135
 
136
  # Initialize default model
137
  print("πŸš€ Initializing Storybook Generator...")
138
  current_pipe = load_model("sdxl")
139
  print("βœ… Model loaded and ready!")
140
 
141
+ # OPTIMIZED PROMPT COMPRESSION
142
+ @lru_cache(maxsize=100)
143
+ def compress_prompt(text, style="childrens_book"):
144
+ """Cache compressed prompts to avoid recomputation"""
145
+ # Simple compression: remove redundant words and shorten
146
+ words = text.split()
147
+ if len(words) <= 50:
148
+ return text
149
+
150
+ # Keep first 40 words (most important part) and key descriptors
151
+ compressed = ' '.join(words[:40])
152
+
153
+ # Add style context
154
+ style_context = {
155
+ "childrens_book": "children's book style",
156
+ "realistic": "realistic style",
157
+ "fantasy": "fantasy style",
158
+ "anime": "anime style"
159
+ }
160
+
161
+ return f"{compressed}... {style_context.get(style, '')} masterpiece 4K"
162
+
163
+ def create_optimized_prompt(scene_visual, characters, style="childrens_book", page_number=1):
164
+ """Create optimized prompt within token limits"""
165
+ # Compress the scene visual
166
+ scene_compressed = compress_prompt(scene_visual, style)
167
+
168
+ # Extract character essentials
169
+ char_descriptors = []
170
  if characters:
171
  for char in characters:
172
+ if hasattr(char, 'name'):
173
+ name = char.name
174
  desc = char.description
175
  elif isinstance(char, dict):
176
+ name = char.get('name', 'Unknown')
177
  desc = char.get('description', '')
178
  else:
179
+ continue
180
 
181
+ # Extract key features
182
  import re
 
 
 
 
183
  # Get species/type
184
+ species_match = re.search(r'(rabbit|hedgehog|bird|dog|cat|fox|bear|dragon|human|girl|boy)', desc, re.IGNORECASE)
185
  species = species_match.group(1) if species_match else "character"
186
 
187
+ # Get color
188
+ color_match = re.search(r'(white|black|brown|blue|red|green|yellow|golden|pink)', desc, re.IGNORECASE)
189
  color = color_match.group(1) if color_match else ""
190
 
191
+ char_descriptors.append(f"{color} {species}".strip())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+ # Build the final prompt
194
+ continuity = f"scene {page_number} " if page_number > 1 else ""
195
+ chars_text = f"Characters: {', '.join(char_descriptors)}. " if char_descriptors else ""
196
 
197
+ final_prompt = f"{continuity}{scene_compressed}. {chars_text}masterpiece best quality 4K"
 
198
 
199
+ # Ensure it's under 60 words
200
+ words = final_prompt.split()
 
 
 
 
 
201
  if len(words) > 60:
202
+ final_prompt = ' '.join(words[:60])
203
 
204
+ return final_prompt
205
 
206
  def enhance_prompt(scene_visual, characters, style="childrens_book", page_number=1):
207
+ """Create optimized prompt"""
208
+ main_prompt = create_optimized_prompt(scene_visual, characters, style, page_number)
 
 
 
209
 
210
+ print(f"πŸ“ Optimized prompt: {main_prompt}")
211
  print(f"πŸ“ Length: {len(main_prompt.split())} words")
212
 
213
  # Negative prompt
214
  negative_prompt = (
215
+ "blurry, low quality, ugly, deformed, bad anatomy, "
216
+ "watermark, signature, text, username, multiple people, "
217
+ "inconsistent features, low resolution"
 
 
 
218
  )
219
 
220
  return main_prompt, negative_prompt
 
264
  return character_seeds[story_title][seed_key]
265
 
266
  def generate_storybook_page(scene_visual, story_title, sequence_number, scene_text, characters, model_choice="sdxl", style="childrens_book"):
267
+ """Generate a single page - OPTIMIZED VERSION"""
268
  global current_pipe, current_model_name
269
 
270
  try:
271
+ # ONLY load model if different from current
272
  if model_choice != current_model_name:
273
+ print(f"πŸ”„ Switching to model: {model_choice}")
274
  current_pipe = load_model(model_choice)
275
+ else:
276
+ print(f"βœ… Using already loaded model: {model_choice}")
277
 
278
  enhanced_prompt, negative_prompt = enhance_prompt(
279
  scene_visual, characters, style, sequence_number
280
  )
281
 
282
  print(f"πŸ“– Generating page {sequence_number}")
 
283
 
284
  if characters:
285
  char_names = []
 
302
  scene_seed = hash(f"{story_title}_{sequence_number}") % 1000000
303
  generator.manual_seed(scene_seed)
304
 
305
+ # Generate image with optimized parameters
306
+ print("⏳ Starting image generation...")
307
+ start_time = time.time()
308
+
309
  image = current_pipe(
310
  prompt=enhanced_prompt,
311
  negative_prompt=negative_prompt,
312
+ num_inference_steps=25, # Reduced from 35 for speed
313
+ guidance_scale=7.0,
314
+ width=512, # Reduced from 768 for speed
315
+ height=512,
316
  generator=generator
317
  ).images[0]
318
 
319
+ gen_time = time.time() - start_time
320
+ print(f"βœ… Image generated in {gen_time:.1f} seconds")
321
+
322
  save_status = save_complete_storybook_page(image, story_title, sequence_number, scene_text)
323
  return image, save_status
324
 
 
326
  return None, f"❌ Generation failed: {str(e)}"
327
 
328
  def batch_generate_complete_storybook(story_title, scenes_data, characters, model_choice="sdxl", style="childrens_book"):
329
+ """Batch generation with significant optimizations"""
330
+ global current_pipe
331
 
332
  results = []
333
  status_messages = []
334
 
335
+ print(f"πŸ“š Starting OPTIMIZED batch generation: {story_title}")
336
  print(f"πŸ“– Pages: {len(scenes_data)}")
337
  print(f"πŸ‘€ Characters: {len(characters)}")
338
 
339
+ # Load model ONCE at the beginning
340
+ print(f"πŸ”§ Loading model once for entire batch...")
 
341
  current_pipe = load_model(model_choice)
342
+ batch_start_time = time.time()
343
 
344
  for i, scene_data in enumerate(scenes_data, 1):
345
  try:
 
 
 
346
  scene_visual = scene_data.get('visual', '')
347
  scene_text = scene_data.get('text', '')
348
 
349
  print(f"πŸ”„ Generating page {i}/{len(scenes_data)}...")
350
+ page_start_time = time.time()
351
+
352
  image, status = generate_storybook_page(
353
  scene_visual, story_title, i, scene_text, characters, model_choice, style
354
  )
355
 
356
+ page_time = time.time() - page_start_time
357
+ print(f"⏰ Page {i} completed in {page_time:.1f} seconds")
358
+
359
  if image:
360
  results.append((f"Page {i}", image, scene_text))
361
  status_messages.append(f"Page {i}: {status}")
362
 
363
+ # Clean memory every 3 pages
364
+ if i % 3 == 0:
365
+ cleanup_memory()
366
 
367
  except Exception as e:
368
  error_msg = f"❌ Failed page {i}: {str(e)}"
369
  print(error_msg)
370
  status_messages.append(error_msg)
371
 
372
+ total_time = time.time() - batch_start_time
373
  print(f"βœ… Batch completed in {total_time:.2f} seconds")
374
+ print(f"πŸ“Š Average: {total_time/len(scenes_data):.1f} seconds per page")
375
 
376
  return results, "\n".join(status_messages)
377
 
 
384
 
385
  start_time = time.time()
386
  scenes_data = [{"visual": scene.visual, "text": scene.text} for scene in request.scenes]
387
+
388
+ # Convert characters to dict ONCE
389
+ characters_dict = []
390
+ for char in request.characters:
391
+ if hasattr(char, 'dict'):
392
+ characters_dict.append(char.dict())
393
+ else:
394
+ characters_dict.append({"name": getattr(char, 'name', 'Unknown'),
395
+ "description": getattr(char, 'description', '')})
396
 
397
  results, status = batch_generate_complete_storybook(
398
  request.story_title,