yukee1992 committed on
Commit
2111d34
·
verified ·
1 Parent(s): 5d8ce04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -285
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
4
  from PIL import Image
5
  import io
6
  import requests
@@ -18,6 +18,8 @@ import random
18
  import gc
19
  import psutil
20
  import threading
 
 
21
 
22
  # External OCI API URL
23
  OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"
@@ -48,163 +50,268 @@ class StorybookRequest(BaseModel):
48
  story_title: str
49
  scenes: List[StoryScene]
50
  characters: List[CharacterDescription] = []
51
- model_choice: str = "dreamshaper-8"
52
  style: str = "childrens_book"
53
 
54
- # HIGH-QUALITY MODEL SELECTION
55
  MODEL_CHOICES = {
 
 
56
  "dreamshaper-8": "lykon/dreamshaper-8",
57
  "realistic-vision": "SG161222/Realistic_Vision_V5.1",
58
- "anything-v5": "andite/anything-v5.0",
59
- "openjourney": "prompthero/openjourney",
60
- "sd-2.1": "stabilityai/stable-diffusion-2-1",
61
  }
62
 
63
- # GLOBAL MODEL CACHE - Load once, reuse forever
64
  model_cache = {}
65
  current_model_name = None
66
  current_pipe = None
67
 
68
  # Character consistency tracking
69
  character_descriptions = {}
70
- character_seeds = {} # Store seeds for consistent character generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Memory monitoring function
73
  def monitor_memory():
74
- """Monitor current memory usage"""
75
  try:
76
  process = psutil.Process()
77
- memory_usage = process.memory_info().rss / 1024 / 1024 # MB
78
  print(f"📊 Memory usage: {memory_usage:.2f} MB")
79
  return memory_usage
80
  except:
81
- print("⚠️ Could not monitor memory (psutil not available)")
82
  return 0
83
 
84
- # Memory cleanup function
85
  def cleanup_memory():
86
- """Clean up memory and cache"""
87
  gc.collect()
88
  if torch.cuda.is_available():
89
  torch.cuda.empty_cache()
90
  print("🧹 Memory cleaned up")
91
 
92
- def load_model(model_name="dreamshaper-8"):
93
- """Load model into global cache - runs only once per model"""
94
  global model_cache, current_model_name, current_pipe
95
 
96
- # Return cached model if already loaded
97
  if model_name in model_cache:
98
- print(f"✅ Using cached model: {model_name}")
99
  current_pipe = model_cache[model_name]
100
  current_model_name = model_name
101
  return current_pipe
102
 
103
- print(f"🔄 Loading model for the first time: {model_name}")
104
  try:
105
- model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
106
-
107
- pipe = StableDiffusionPipeline.from_pretrained(
108
- model_id,
109
- torch_dtype=torch.float32,
110
- safety_checker=None,
111
- requires_safety_checker=False
112
- )
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- # Use better scheduler for quality
115
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
116
  pipe = pipe.to("cpu")
117
-
118
- # Cache the model for future use
119
  model_cache[model_name] = pipe
120
  current_pipe = pipe
121
  current_model_name = model_name
122
 
123
- print(f"✅ Model loaded and cached: {model_name}")
124
- monitor_memory()
125
  return pipe
126
 
127
  except Exception as e:
128
  print(f"❌ Model loading failed: {e}")
129
- # Fallback to SD 1.5
130
  pipe = StableDiffusionPipeline.from_pretrained(
131
  "runwayml/stable-diffusion-v1-5",
132
- torch_dtype=torch.float32,
133
- safety_checker=None,
134
- requires_safety_checker=False
135
  ).to("cpu")
136
  model_cache[model_name] = pipe
137
- current_pipe = pipe
138
  return pipe
139
 
140
- # Load the default model once at startup
141
  print("🚀 Initializing Storybook Generator...")
142
- current_pipe = load_model("dreamshaper-8")
143
- print("✅ Default model loaded and ready!")
144
- monitor_memory()
145
 
146
- # PROFESSIONAL PROMPT ENGINEERING
147
- def enhance_prompt(prompt, style="childrens_book"):
148
- """Transform basic prompts into professional-grade prompts"""
149
-
150
- style_templates = {
151
- "childrens_book": [
152
- "masterpiece, best quality, 4K, ultra detailed, children's book illustration",
153
- "watercolor painting, whimsical, cute, charming, storybook style",
154
- "vibrant colors, soft lighting, magical, enchanting, dreamlike",
155
- "Pixar style, Disney animation, high detail, professional artwork"
156
- ],
157
- "realistic": [
158
- "photorealistic, 8K, ultra detailed, professional photography",
159
- "sharp focus, studio lighting, high resolution, intricate details",
160
- "realistic textures, natural lighting, cinematic quality"
161
- ],
162
- "fantasy": [
163
- "epic fantasy art, digital painting, concept art, trending on artstation",
164
- "magical, mystical, ethereal, otherworldly, fantasy illustration",
165
- "dynamic composition, dramatic lighting, highly detailed"
166
- ],
167
- "anime": [
168
- "anime style, Japanese animation, high quality, detailed artwork",
169
- "beautiful anime illustration, vibrant colors, clean lines",
170
- "studio ghibli style, makoto shinkai, professional anime art"
171
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  }
173
 
174
- templates = style_templates.get(style, style_templates["childrens_book"])
175
- style_prompt = templates[0]
176
 
177
- enhanced = f"{style_prompt}, {prompt}"
 
 
 
 
 
 
 
 
 
178
 
179
- quality_boosters = [
180
- "intricate details", "beautiful composition", "perfect lighting",
181
- "professional artwork", "award winning", "trending on artstation"
182
- ]
183
 
184
- boosters = random.sample(quality_boosters, 2)
185
- enhanced += ", " + ", ".join(boosters)
186
 
 
 
 
 
 
 
 
 
 
187
  negative_prompt = (
188
- "blurry, low quality, low resolution, ugly, deformed, poorly drawn, "
189
- "bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, "
190
- "disconnected limbs, mutation, mutated, ugly, disgusting, bad art, "
191
- "beginner, amateur, distorted, watermark, signature, text, username"
 
 
 
192
  )
193
 
194
- return enhanced, negative_prompt
195
 
196
  def save_complete_storybook_page(image, story_title, sequence_number, scene_text):
197
- """Save image AND text to OCI with organized structure"""
198
  try:
199
- # Convert image to bytes
200
  img_bytes = io.BytesIO()
201
  image.save(img_bytes, format='PNG')
202
  img_data = img_bytes.getvalue()
203
 
204
- # Clean title for filenames
205
  clean_title = re.sub(r'[^a-zA-Z0-9_\-]', '', story_title.strip().replace(' ', '_'))
206
-
207
- # Create filenames
208
  image_filename = f"page_{sequence_number:03d}_{clean_title}.png"
209
  text_filename = f"page_{sequence_number:03d}_{clean_title}.txt"
210
 
@@ -229,112 +336,90 @@ def save_complete_storybook_page(image, story_title, sequence_number, scene_text
229
  except Exception as e:
230
  return f"❌ Save failed: {str(e)}"
231
 
232
- def enhance_with_character_context(scene_visual, story_title, characters):
233
- """Add character descriptions to maintain consistency"""
234
- if characters:
235
- character_context = " ".join([f"{char.name}: {char.description}" for char in characters])
236
- return f"Character descriptions: {character_context}. {scene_visual}"
237
- return scene_visual
238
-
239
- def get_character_seed(story_title, character_name):
240
- """Get consistent seed for character generation"""
241
  if story_title not in character_seeds:
242
  character_seeds[story_title] = {}
243
 
244
- if character_name not in character_seeds[story_title]:
245
- # Generate a stable seed based on character name and story title
246
- seed_value = hash(f"{story_title}_{character_name}") % 1000000
247
- character_seeds[story_title][character_name] = seed_value
248
- print(f"🌱 Seed for {character_name}: {seed_value}")
 
249
 
250
- return character_seeds[story_title][character_name]
251
 
252
- def generate_storybook_page(scene_visual, story_title, sequence_number, scene_text, characters, model_choice="dreamshaper-8", style="childrens_book"):
253
- """Generate a storybook page with character consistency"""
254
  global current_pipe, current_model_name
255
 
256
  try:
257
- # Switch model if different from current - BUT DON'T RELOAD UNLESS NECESSARY
258
  if model_choice != current_model_name:
259
- print(f"🔄 Switching to model: {model_choice}")
260
- current_pipe = load_model(model_choice) # This uses cached version if available
261
-
262
- # ENHANCE PROMPT WITH CHARACTER CONTEXT
263
- enhanced_visual = enhance_with_character_context(scene_visual, story_title, characters)
264
 
265
- # Add scene continuity context
266
- if sequence_number > 1:
267
- enhanced_visual = f"Scene {sequence_number}, maintain character consistency from previous scenes. {enhanced_visual}"
268
 
269
- enhanced_prompt, negative_prompt = enhance_prompt(enhanced_visual, style)
 
270
 
271
- print(f"📖 Generating page {sequence_number} for: {story_title}")
272
  if characters:
273
- print(f"👤 Characters: {[char.name for char in characters]}")
 
 
 
 
 
 
274
 
275
- # Use consistent seed for character generation
276
  generator = torch.Generator(device="cpu")
 
277
  if characters:
278
- # Use seed from main character for consistency
279
- main_char_seed = get_character_seed(story_title, characters[0].name)
 
280
  generator.manual_seed(main_char_seed)
281
- print(f"🌱 Using seed {main_char_seed} for character consistency")
282
  else:
283
- seed = int(time.time())
284
- generator.manual_seed(seed)
285
- print(f"🌱 Using timestamp seed {seed}")
286
 
287
- # Generate high-quality image - USE THE GLOBAL current_pipe
288
  image = current_pipe(
289
  prompt=enhanced_prompt,
290
  negative_prompt=negative_prompt,
291
- num_inference_steps=30,
292
- guidance_scale=8.5,
293
  width=768,
294
  height=768,
295
  generator=generator
296
  ).images[0]
297
 
298
- # Save both image and text
299
  save_status = save_complete_storybook_page(image, story_title, sequence_number, scene_text)
300
-
301
  return image, save_status
302
 
303
  except Exception as e:
304
  return None, f"❌ Generation failed: {str(e)}"
305
 
306
- def batch_generate_complete_storybook(story_title, scenes_data, characters, model_choice="dreamshaper-8", style="childrens_book"):
307
- """Generate complete storybook with memory management"""
308
  global character_descriptions, current_pipe
309
 
310
  results = []
311
  status_messages = []
312
 
313
- print(f"📚 Starting batch generation for: {story_title}")
314
- print(f"📖 Total pages: {len(scenes_data)}")
315
  print(f"👤 Characters: {len(characters)}")
316
- print(f"🎨 Using model: {model_choice}")
317
 
318
- # Initial memory check
319
- initial_memory = monitor_memory()
320
-
321
- # Store character descriptions for this story
322
  if characters:
323
  character_descriptions[story_title] = characters
324
- print(f"✅ Character context stored for {story_title}")
325
 
326
- # Load model once at the beginning
327
- print(f"🔧 Loading model for this storybook...")
328
  current_pipe = load_model(model_choice)
329
-
330
  start_time = time.time()
331
 
332
  for i, scene_data in enumerate(scenes_data, 1):
333
  try:
334
- # Clean memory every 2 pages
335
  if i % 2 == 0:
336
  cleanup_memory()
337
- monitor_memory()
338
 
339
  scene_visual = scene_data.get('visual', '')
340
  scene_text = scene_data.get('text', '')
@@ -348,43 +433,34 @@ def batch_generate_complete_storybook(story_title, scenes_data, characters, mode
348
  results.append((f"Page {i}", image, scene_text))
349
  status_messages.append(f"Page {i}: {status}")
350
 
 
 
 
351
  except Exception as e:
352
  error_msg = f"❌ Failed page {i}: {str(e)}"
353
  print(error_msg)
354
  status_messages.append(error_msg)
355
- # Continue with next page instead of stopping
356
 
357
  total_time = time.time() - start_time
358
- final_memory = monitor_memory()
359
-
360
- print(f"✅ Batch generation completed in {total_time:.2f} seconds")
361
- print(f"📊 Memory delta: {final_memory - initial_memory:.2f} MB")
362
 
363
  return results, "\n".join(status_messages)
364
 
365
- # FastAPI endpoint for n8n
366
  @app.post("/api/generate-storybook")
367
  async def api_generate_storybook(request: StorybookRequest):
368
- """API endpoint for n8n automation - OPTIMIZED with character consistency"""
369
  try:
370
- print(f"📚 Received storybook request: {request.story_title}")
371
- print(f"📖 Pages to generate: {len(request.scenes)}")
372
- print(f"👤 Characters received: {len(request.characters)}")
373
-
374
- if request.characters:
375
- for char in request.characters:
376
- print(f" - {char.name}: {char.description[:50]}...")
377
 
378
  start_time = time.time()
379
-
380
- # Convert to scene data format
381
  scenes_data = [{"visual": scene.visual, "text": scene.text} for scene in request.scenes]
 
382
 
383
- # Generate storybook (model loads only once)
384
  results, status = batch_generate_complete_storybook(
385
  request.story_title,
386
  scenes_data,
387
- request.characters,
388
  request.model_choice,
389
  request.style
390
  )
@@ -412,113 +488,24 @@ async def api_generate_storybook(request: StorybookRequest):
412
  except Exception as e:
413
  error_msg = f"Storybook generation failed: {str(e)}"
414
  print(f"❌ {error_msg}")
415
- import traceback
416
- traceback.print_exc()
417
  raise HTTPException(status_code=500, detail=error_msg)
418
 
419
- # Async processing endpoint for large batches
420
- @app.post("/api/generate-storybook-async")
421
- async def api_generate_storybook_async(request: StorybookRequest, background_tasks: BackgroundTasks):
422
- """Async endpoint that processes images in background with memory management"""
423
- try:
424
- # Store the request and return immediate response
425
- request_id = f"{request.story_title}_{int(time.time())}"
426
-
427
- # Start background task
428
- background_tasks.add_task(
429
- process_storybook_async,
430
- request_id,
431
- request.dict()
432
- )
433
-
434
- return {
435
- "status": "processing",
436
- "request_id": request_id,
437
- "message": f"Started processing {len(request.scenes)} pages for '{request.story_title}'",
438
- "estimated_time": f"Approximately {len(request.scenes) * 45} seconds"
439
- }
440
-
441
- except Exception as e:
442
- raise HTTPException(status_code=500, detail=str(e))
443
-
444
- def process_storybook_async(request_id, request_data):
445
- """Background task for async processing with memory management"""
446
- try:
447
- print(f"🔧 Starting async processing for request: {request_id}")
448
- print(f"📖 Pages to process: {len(request_data['scenes'])}")
449
-
450
- # Convert dictionary back to proper objects for character handling
451
- characters = []
452
- if 'characters' in request_data and request_data['characters']:
453
- # Convert character dicts back to CharacterDescription objects
454
- for char_dict in request_data['characters']:
455
- characters.append(CharacterDescription(**char_dict))
456
-
457
- # Initial memory check
458
- initial_memory = monitor_memory()
459
-
460
- for i, scene in enumerate(request_data['scenes']):
461
- try:
462
- print(f"🔄 Processing page {i+1}/{len(request_data['scenes'])} for request {request_id}")
463
-
464
- # Generate single page - pass the converted character objects
465
- image, status = generate_storybook_page(
466
- scene['visual'],
467
- request_data['story_title'],
468
- i+1,
469
- scene['text'],
470
- characters, # Pass the converted objects, not raw dicts
471
- request_data.get('model_choice', 'dreamshaper-8'),
472
- request_data.get('style', 'childrens_book')
473
- )
474
-
475
- print(f"✅ Page {i+1} completed: {status}")
476
-
477
- # Clean memory after each page
478
- cleanup_memory()
479
- current_memory = monitor_memory()
480
-
481
- # Add delay between pages to prevent overload
482
- if i < len(request_data['scenes']) - 1: # Don't sleep after last page
483
- sleep_time = 5 # 5 second delay between pages
484
- print(f"⏳ Waiting {sleep_time} seconds before next page...")
485
- time.sleep(sleep_time)
486
-
487
- except Exception as e:
488
- error_msg = f"❌ Failed page {i+1}: {str(e)}"
489
- print(error_msg)
490
- # Continue with next page
491
- continue
492
-
493
- final_memory = monitor_memory()
494
- print(f"✅ Completed async processing for {request_id}")
495
- print(f"📊 Total memory change: {final_memory - initial_memory:.2f} MB")
496
-
497
- except Exception as e:
498
- print(f"❌ Async processing failed for {request_id}: {e}")
499
-
500
- # Health check endpoint with memory info
501
  @app.get("/api/health")
502
  async def health_check():
503
- memory_info = monitor_memory()
504
  return {
505
  "status": "healthy",
506
  "service": "Storybook Generator API",
507
  "timestamp": datetime.now().isoformat(),
508
- "memory_usage_mb": round(memory_info, 2),
509
  "models_loaded": list(model_cache.keys()),
510
- "current_model": current_model_name,
511
- "cached_models_count": len(model_cache),
512
- "stories_tracked": len(character_descriptions)
513
  }
514
 
515
- # Gradio Interface Functions
516
  def generate_single_page(prompt, story_title, scene_text, model_choice, style):
517
- """Generate a single page for Gradio interface"""
518
  if not prompt or not story_title:
519
  return None, "❌ Please enter both scene description and story title"
520
 
521
- # Ensure model is loaded
522
  global current_pipe
523
  if current_model_name != model_choice:
524
  current_pipe = load_model(model_choice)
@@ -528,25 +515,18 @@ def generate_single_page(prompt, story_title, scene_text, model_choice, style):
528
  )
529
  return image, status
530
 
531
- # Create the Gradio interface
532
  with gr.Blocks(title="Storybook Generator", theme="soft") as demo:
533
  gr.Markdown("# 📚 Storybook Generator")
534
  gr.Markdown("Create beautiful storybooks with consistent characters")
535
 
536
  with gr.Row():
537
  with gr.Column(scale=1):
538
- story_title_input = gr.Textbox(
539
- label="Story Title",
540
- placeholder="Enter your story title...",
541
- lines=1
542
- )
543
-
544
  model_choice = gr.Dropdown(
545
  label="AI Model",
546
  choices=list(MODEL_CHOICES.keys()),
547
- value="dreamshaper-8"
548
  )
549
-
550
  style_choice = gr.Dropdown(
551
  label="Art Style",
552
  choices=["childrens_book", "realistic", "fantasy", "anime"],
@@ -554,18 +534,8 @@ with gr.Blocks(title="Storybook Generator", theme="soft") as demo:
554
  )
555
 
556
  with gr.Column(scale=2):
557
- prompt_input = gr.Textbox(
558
- label="Visual Description",
559
- placeholder="Describe the scene for image generation...",
560
- lines=3
561
- )
562
-
563
- text_input = gr.Textbox(
564
- label="Story Text (Optional)",
565
- placeholder="Enter the story text for this page...",
566
- lines=2
567
- )
568
-
569
  generate_btn = gr.Button("✨ Generate Single Page", variant="primary")
570
  image_output = gr.Image(label="Generated Page", height=400)
571
  status_output = gr.Textbox(label="Status", interactive=False)
@@ -576,21 +546,9 @@ with gr.Blocks(title="Storybook Generator", theme="soft") as demo:
576
  outputs=[image_output, status_output]
577
  )
578
 
579
- # Mount Gradio app to FastAPI
580
  app = gr.mount_gradio_app(app, demo, path="/")
581
 
582
- # For Hugging Face Spaces deployment
583
- def get_app():
584
- return app
585
-
586
  if __name__ == "__main__":
587
  print("🚀 Starting Storybook Generator API...")
588
- print("📚 Available models:", list(MODEL_CHOICES.keys()))
589
- print("🌐 API endpoints:")
590
- print(" - POST /api/generate-storybook")
591
- print(" - POST /api/generate-storybook-async (for large batches)")
592
- print(" - GET /api/health")
593
- print(" - GET / (Gradio UI)")
594
-
595
  import uvicorn
596
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
4
  from PIL import Image
5
  import io
6
  import requests
 
18
  import gc
19
  import psutil
20
  import threading
21
+ from transformers import CLIPTokenizer, CLIPTextModel
22
+ import numpy as np
23
 
24
  # External OCI API URL
25
  OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"
 
50
  story_title: str
51
  scenes: List[StoryScene]
52
  characters: List[CharacterDescription] = []
53
+ model_choice: str = "sdxl"
54
  style: str = "childrens_book"
55
 
56
+ # MODEL SELECTION
57
  MODEL_CHOICES = {
58
+ "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
59
+ "sdxl-turbo": "stabilityai/sdxl-turbo",
60
  "dreamshaper-8": "lykon/dreamshaper-8",
61
  "realistic-vision": "SG161222/Realistic_Vision_V5.1",
 
 
 
62
  }
63
 
64
+ # GLOBAL MODEL CACHE
65
  model_cache = {}
66
  current_model_name = None
67
  current_pipe = None
68
 
69
  # Character consistency tracking
70
  character_descriptions = {}
71
+ character_seeds = {}
72
+
73
+ # CLIP tokenizer for long prompt handling
74
+ clip_tokenizer = None
75
+ clip_model = None
76
+
77
+ def initialize_clip():
78
+ """Initialize CLIP for long prompt processing"""
79
+ global clip_tokenizer, clip_model
80
+ try:
81
+ clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
82
+ clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
83
+ print("✅ CLIP model loaded for long prompt processing")
84
+ except Exception as e:
85
+ print(f"❌ CLIP loading failed: {e}")
86
 
87
  # Memory monitoring function
88
  def monitor_memory():
 
89
  try:
90
  process = psutil.Process()
91
+ memory_usage = process.memory_info().rss / 1024 / 1024
92
  print(f"📊 Memory usage: {memory_usage:.2f} MB")
93
  return memory_usage
94
  except:
 
95
  return 0
96
 
 
97
  def cleanup_memory():
 
98
  gc.collect()
99
  if torch.cuda.is_available():
100
  torch.cuda.empty_cache()
101
  print("🧹 Memory cleaned up")
102
 
103
+ def load_model(model_name="sdxl"):
 
104
  global model_cache, current_model_name, current_pipe
105
 
 
106
  if model_name in model_cache:
 
107
  current_pipe = model_cache[model_name]
108
  current_model_name = model_name
109
  return current_pipe
110
 
111
+ print(f"🔄 Loading model: {model_name}")
112
  try:
113
+ if model_name in ["sdxl", "sdxl-turbo"]:
114
+ model_id = MODEL_CHOICES[model_name]
115
+ pipe = StableDiffusionXLPipeline.from_pretrained(
116
+ model_id,
117
+ torch_dtype=torch.float32,
118
+ use_safetensors=True,
119
+ safety_checker=None,
120
+ requires_safety_checker=False
121
+ )
122
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
123
+ else:
124
+ model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
125
+ pipe = StableDiffusionPipeline.from_pretrained(
126
+ model_id,
127
+ torch_dtype=torch.float32,
128
+ safety_checker=None,
129
+ requires_safety_checker=False
130
+ )
131
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
132
 
 
 
133
  pipe = pipe.to("cpu")
 
 
134
  model_cache[model_name] = pipe
135
  current_pipe = pipe
136
  current_model_name = model_name
137
 
138
+ print(f"✅ Model loaded: {model_name}")
 
139
  return pipe
140
 
141
  except Exception as e:
142
  print(f"❌ Model loading failed: {e}")
 
143
  pipe = StableDiffusionPipeline.from_pretrained(
144
  "runwayml/stable-diffusion-v1-5",
145
+ torch_dtype=torch.float32
 
 
146
  ).to("cpu")
147
  model_cache[model_name] = pipe
 
148
  return pipe
149
 
150
+ # Initialize CLIP and default model
151
  print("🚀 Initializing Storybook Generator...")
152
+ initialize_clip()
153
+ current_pipe = load_model("sdxl")
154
+ print("✅ Models loaded and ready!")
155
 
156
+ # ADVANCED LONG PROMPT HANDLING
157
+ def segment_long_prompt(long_prompt, max_tokens=75):
158
+ """
159
+ Split long prompt into meaningful segments using CLIP tokenization
160
+ and semantic analysis
161
+ """
162
+ if clip_tokenizer is None:
163
+ # Fallback: simple sentence splitting
164
+ sentences = [s.strip() for s in long_prompt.split('.') if s.strip()]
165
+ return sentences
166
+
167
+ # Tokenize with CLIP to understand semantic boundaries
168
+ tokens = clip_tokenizer(long_prompt, return_tensors="pt", truncation=False)
169
+ token_count = tokens.input_ids.shape[1]
170
+
171
+ if token_count <= max_tokens:
172
+ return [long_prompt]
173
+
174
+ print(f"📝 Segmenting very long prompt: {token_count} tokens")
175
+
176
+ # Split into sentences first
177
+ sentences = [s.strip() for s in long_prompt.split('.') if s.strip()]
178
+ segments = []
179
+ current_segment = ""
180
+
181
+ for sentence in sentences:
182
+ test_segment = current_segment + ". " + sentence if current_segment else sentence
183
+ test_tokens = clip_tokenizer(test_segment, return_tensors="pt", truncation=False)
184
+
185
+ if test_tokens.input_ids.shape[1] <= max_tokens:
186
+ current_segment = test_segment
187
+ else:
188
+ if current_segment:
189
+ segments.append(current_segment)
190
+ current_segment = sentence
191
+
192
+ if current_segment:
193
+ segments.append(current_segment)
194
+
195
+ return segments
196
+
197
+ def create_prompt_hierarchy(full_prompt):
198
+ """
199
+ Create a hierarchical prompt structure with main focus and supporting details
200
+ """
201
+ segments = segment_long_prompt(full_prompt)
202
+
203
+ if len(segments) == 1:
204
+ return full_prompt
205
+
206
+ # The first segment is most important (main subject/action)
207
+ main_prompt = segments[0]
208
+
209
+ # Remaining segments become supporting context with weights
210
+ supporting_context = ""
211
+ for i, segment in enumerate(segments[1:], 1):
212
+ weight = 1.3 - (i * 0.1) # Decreasing weight for later segments
213
+ weight = max(0.8, min(1.5, weight))
214
+ supporting_context += f" ({segment}:{weight:.1f})"
215
+
216
+ final_prompt = f"{main_prompt}.{supporting_context}. masterpiece, best quality, 4K"
217
+ return final_prompt
218
+
219
+ def extract_key_phrases(prompt, max_phrases=10):
220
+ """
221
+ Extract the most important phrases from very long prompts
222
+ """
223
+ # Simple heuristic: nouns, adjectives, and verbs are important
224
+ words = prompt.split()
225
+ important_words = []
226
+
227
+ # Prioritize words after colons, in parentheses, or quoted
228
+ for i, word in enumerate(words):
229
+ if (':' in word or '(' in word or '[' in word or
230
+ word.isupper() or (i > 0 and words[i-1][-1] == ':')):
231
+ important_words.append(word)
232
+
233
+ # Also take first few words of each sentence
234
+ sentences = prompt.split('.')
235
+ for sentence in sentences:
236
+ first_words = sentence.strip().split()[:3]
237
+ important_words.extend(first_words)
238
+
239
+ # Remove duplicates and limit
240
+ important_words = list(set(important_words))[:max_phrases]
241
+ return " ".join(important_words)
242
+
243
+ def enhance_prompt(scene_visual, characters, style="childrens_book", page_number=1):
244
+ """Create comprehensive prompt with NO length limits"""
245
+
246
+ # Character context - include ALL details
247
+ character_context = ""
248
+ if characters:
249
+ char_descriptions = []
250
+ for char in characters:
251
+ if hasattr(char, 'description'):
252
+ char_descriptions.append(char.description)
253
+ elif isinstance(char, dict):
254
+ char_descriptions.append(char.get('description', ''))
255
+ character_context = " ".join(char_descriptions)
256
+ character_context = f"Character details: {character_context}."
257
+
258
+ # Scene continuity context
259
+ continuity_context = f"Scene {page_number}, " if page_number > 1 else ""
260
+
261
+ # Style templates
262
+ style_presets = {
263
+ "childrens_book": "children's book illustration, watercolor style, whimsical, charming, vibrant colors, soft lighting, storybook art, detailed backgrounds, cute characters, magical atmosphere",
264
+ "realistic": "photorealistic, professional photography, natural lighting, detailed, sharp focus, high resolution, realistic textures, studio quality, cinematic lighting",
265
+ "fantasy": "fantasy art, digital painting, magical, epic, concept art, dramatic lighting, mystical, otherworldly, detailed environments, heroic",
266
+ "anime": "anime style, Japanese animation, clean lines, vibrant colors, cel shading, detailed eyes, dynamic poses, manga style, professional animation"
267
  }
268
 
269
+ style_prompt = style_presets.get(style, style_presets["childrens_book"])
 
270
 
271
+ # Build COMPREHENSIVE prompt with ALL details
272
+ full_prompt = f"""
273
+ {continuity_context}
274
+ {scene_visual}.
275
+ {character_context}
276
+ Art style: {style_prompt}.
277
+ Technical quality: masterpiece, best quality, 4K resolution, ultra detailed,
278
+ professional artwork, award winning, trending on artstation, perfect composition,
279
+ ideal lighting, beautiful colors, no errors, perfect anatomy, consistent style
280
+ """
281
 
282
+ # Clean up the prompt
283
+ full_prompt = ' '.join(full_prompt.split()) # Remove extra whitespace
 
 
284
 
285
+ print(f"📝 Raw prompt length: {len(full_prompt.split())} words")
 
286
 
287
+ # Use hierarchical prompt creation for very long prompts
288
+ if len(full_prompt.split()) > 100:
289
+ optimized_prompt = create_prompt_hierarchy(full_prompt)
290
+ else:
291
+ optimized_prompt = full_prompt
292
+
293
+ print(f"📝 Final prompt length: {len(optimized_prompt.split())} words")
294
+
295
+ # Negative prompt
296
  negative_prompt = (
297
+ "blurry, low quality, ugly, deformed, poorly drawn, bad anatomy, "
298
+ "wrong anatomy, extra limb, missing limb, floating limbs, "
299
+ "disconnected limbs, mutation, mutated, disgusting, bad art, "
300
+ "beginner, amateur, distorted, watermark, signature, text, username, "
301
+ "multiple people, crowd, group, different characters, inconsistent features, "
302
+ "changed appearance, different face, altered features, low resolution, "
303
+ "jpeg artifacts, compression artifacts, noise, grain, out of focus"
304
  )
305
 
306
+ return optimized_prompt, negative_prompt
307
 
308
  def save_complete_storybook_page(image, story_title, sequence_number, scene_text):
 
309
  try:
 
310
  img_bytes = io.BytesIO()
311
  image.save(img_bytes, format='PNG')
312
  img_data = img_bytes.getvalue()
313
 
 
314
  clean_title = re.sub(r'[^a-zA-Z0-9_\-]', '', story_title.strip().replace(' ', '_'))
 
 
315
  image_filename = f"page_{sequence_number:03d}_{clean_title}.png"
316
  text_filename = f"page_{sequence_number:03d}_{clean_title}.txt"
317
 
 
336
  except Exception as e:
337
  return f"❌ Save failed: {str(e)}"
338
 
339
+ def get_character_seed(story_title, character_name, page_number):
 
 
 
 
 
 
 
 
340
  if story_title not in character_seeds:
341
  character_seeds[story_title] = {}
342
 
343
+ seed_key = f"{character_name}_{page_number}"
344
+ if seed_key not in character_seeds[story_title]:
345
+ base_seed = hash(f"{story_title}_{character_name}") % 1000000
346
+ page_variation = (page_number * 13) % 1000
347
+ seed_value = (base_seed + page_variation) % 1000000
348
+ character_seeds[story_title][seed_key] = seed_value
349
 
350
+ return character_seeds[story_title][seed_key]
351
 
352
+ def generate_storybook_page(scene_visual, story_title, sequence_number, scene_text, characters, model_choice="sdxl", style="childrens_book"):
 
353
  global current_pipe, current_model_name
354
 
355
  try:
 
356
  if model_choice != current_model_name:
357
+ current_pipe = load_model(model_choice)
 
 
 
 
358
 
359
+ enhanced_prompt, negative_prompt = enhance_prompt(
360
+ scene_visual, characters, style, sequence_number
361
+ )
362
 
363
+ print(f"📖 Generating page {sequence_number}")
364
+ print(f"📝 Prompt preview: {enhanced_prompt[:150]}...")
365
 
 
366
  if characters:
367
+ char_names = []
368
+ for char in characters:
369
+ if hasattr(char, 'name'):
370
+ char_names.append(char.name)
371
+ elif isinstance(char, dict):
372
+ char_names.append(char.get('name', 'unknown'))
373
+ print(f"👤 Characters: {char_names}")
374
 
 
375
  generator = torch.Generator(device="cpu")
376
+
377
  if characters:
378
+ first_char = characters[0]
379
+ char_name = first_char.name if hasattr(first_char, 'name') else first_char.get('name', 'unknown')
380
+ main_char_seed = get_character_seed(story_title, char_name, sequence_number)
381
  generator.manual_seed(main_char_seed)
 
382
  else:
383
+ scene_seed = hash(f"{story_title}_{sequence_number}") % 1000000
384
+ generator.manual_seed(scene_seed)
 
385
 
386
+ # Generate with SDXL which handles long prompts better
387
  image = current_pipe(
388
  prompt=enhanced_prompt,
389
  negative_prompt=negative_prompt,
390
+ num_inference_steps=40, # More steps for better detail
391
+ guidance_scale=7.0,
392
  width=768,
393
  height=768,
394
  generator=generator
395
  ).images[0]
396
 
 
397
  save_status = save_complete_storybook_page(image, story_title, sequence_number, scene_text)
 
398
  return image, save_status
399
 
400
  except Exception as e:
401
  return None, f"❌ Generation failed: {str(e)}"
402
 
403
+ def batch_generate_complete_storybook(story_title, scenes_data, characters, model_choice="sdxl", style="childrens_book"):
 
404
  global character_descriptions, current_pipe
405
 
406
  results = []
407
  status_messages = []
408
 
409
+ print(f"📚 Starting batch generation: {story_title}")
410
+ print(f"📖 Pages: {len(scenes_data)}")
411
  print(f"👤 Characters: {len(characters)}")
 
412
 
 
 
 
 
413
  if characters:
414
  character_descriptions[story_title] = characters
 
415
 
 
 
416
  current_pipe = load_model(model_choice)
 
417
  start_time = time.time()
418
 
419
  for i, scene_data in enumerate(scenes_data, 1):
420
  try:
 
421
  if i % 2 == 0:
422
  cleanup_memory()
 
423
 
424
  scene_visual = scene_data.get('visual', '')
425
  scene_text = scene_data.get('text', '')
 
433
  results.append((f"Page {i}", image, scene_text))
434
  status_messages.append(f"Page {i}: {status}")
435
 
436
+ if i < len(scenes_data):
437
+ time.sleep(2)
438
+
439
  except Exception as e:
440
  error_msg = f"❌ Failed page {i}: {str(e)}"
441
  print(error_msg)
442
  status_messages.append(error_msg)
 
443
 
444
  total_time = time.time() - start_time
445
+ print(f"✅ Batch completed in {total_time:.2f} seconds")
 
 
 
446
 
447
  return results, "\n".join(status_messages)
448
 
449
+ # FastAPI endpoint
450
  @app.post("/api/generate-storybook")
451
  async def api_generate_storybook(request: StorybookRequest):
 
452
  try:
453
+ print(f"📚 Received request: {request.story_title}")
454
+ print(f"📖 Pages: {len(request.scenes)}")
 
 
 
 
 
455
 
456
  start_time = time.time()
 
 
457
  scenes_data = [{"visual": scene.visual, "text": scene.text} for scene in request.scenes]
458
+ characters_dict = [char.dict() for char in request.characters]
459
 
 
460
  results, status = batch_generate_complete_storybook(
461
  request.story_title,
462
  scenes_data,
463
+ characters_dict,
464
  request.model_choice,
465
  request.style
466
  )
 
488
  except Exception as e:
489
  error_msg = f"Storybook generation failed: {str(e)}"
490
  print(f"❌ {error_msg}")
 
 
491
  raise HTTPException(status_code=500, detail=error_msg)
492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
@app.get("/api/health")
async def health_check():
    """Liveness probe: reports memory usage and which models are loaded."""
    report = {
        "status": "healthy",
        "service": "Storybook Generator API",
    }
    report["timestamp"] = datetime.now().isoformat()
    report["memory_usage_mb"] = monitor_memory()
    report["models_loaded"] = list(model_cache.keys())
    report["current_model"] = current_model_name
    return report
503
 
504
+ # Gradio Interface
505
  def generate_single_page(prompt, story_title, scene_text, model_choice, style):
 
506
  if not prompt or not story_title:
507
  return None, "❌ Please enter both scene description and story title"
508
 
 
509
  global current_pipe
510
  if current_model_name != model_choice:
511
  current_pipe = load_model(model_choice)
 
515
  )
516
  return image, status
517
 
 
518
  with gr.Blocks(title="Storybook Generator", theme="soft") as demo:
519
  gr.Markdown("# 📚 Storybook Generator")
520
  gr.Markdown("Create beautiful storybooks with consistent characters")
521
 
522
  with gr.Row():
523
  with gr.Column(scale=1):
524
+ story_title_input = gr.Textbox(label="Story Title", lines=1)
 
 
 
 
 
525
  model_choice = gr.Dropdown(
526
  label="AI Model",
527
  choices=list(MODEL_CHOICES.keys()),
528
+ value="sdxl"
529
  )
 
530
  style_choice = gr.Dropdown(
531
  label="Art Style",
532
  choices=["childrens_book", "realistic", "fantasy", "anime"],
 
534
  )
535
 
536
  with gr.Column(scale=2):
537
+ prompt_input = gr.Textbox(label="Visual Description", lines=5)
538
+ text_input = gr.Textbox(label="Story Text (Optional)", lines=2)
 
 
 
 
 
 
 
 
 
 
539
  generate_btn = gr.Button("✨ Generate Single Page", variant="primary")
540
  image_output = gr.Image(label="Generated Page", height=400)
541
  status_output = gr.Textbox(label="Status", interactive=False)
 
546
  outputs=[image_output, status_output]
547
  )
548
 
 
549
  app = gr.mount_gradio_app(app, demo, path="/")
550
 
 
 
 
 
551
if __name__ == "__main__":
    # Entry point: serve the FastAPI app (with the mounted Gradio UI)
    # on the standard Hugging Face Spaces port.
    print("🚀 Starting Storybook Generator API...")

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)