yukee1992 committed on
Commit
c68efab
·
verified ·
1 Parent(s): 1d24354

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -141
app.py CHANGED
@@ -8,7 +8,8 @@ import os
8
  from datetime import datetime
9
  import re
10
  import time
11
- from typing import List, Optional
 
12
  from fastapi import FastAPI, HTTPException
13
  from pydantic import BaseModel
14
  import gc
@@ -57,6 +58,7 @@ class StorybookResponse(BaseModel):
57
  message: str
58
  folder_path: str
59
  pages: List[dict]
 
60
 
61
  # MODEL SELECTION
62
  MODEL_CHOICES = {
@@ -72,6 +74,7 @@ model_lock = threading.Lock()
72
 
73
  # Character consistency tracking
74
  character_seeds = {}
 
75
 
76
  def monitor_memory():
77
  try:
@@ -125,70 +128,62 @@ print("πŸš€ Initializing Storybook Generator...")
125
  load_model("dreamshaper-8")
126
  print("βœ… Model loaded and ready!")
127
 
128
- # PROMPT OPTIMIZATION
129
  def optimize_prompt(scene_visual, characters, style="childrens_book", page_number=1):
130
  """
131
- Create a prompt that FITS within 77 tokens
132
  """
133
- # Extract character essence
 
 
 
134
  character_essence = ""
135
  if characters:
136
- char_descriptors = []
137
  for char in characters:
138
- desc = char.get('description', '') if isinstance(char, dict) else getattr(char, 'description', '')
139
-
140
- import re
141
- species_match = re.search(r'(rabbit|hedgehog|bird|dog|cat|fox|bear|dragon|human|girl|boy)', desc, re.IGNORECASE)
142
- species = species_match.group(1) if species_match else "character"
143
-
144
- color_match = re.search(r'(white|black|brown|blue|red|green|yellow|golden|pink)', desc, re.IGNORECASE)
145
- color = color_match.group(1) if color_match else ""
146
-
147
- key_feature = ""
148
- if 'glasses' in desc.lower(): key_feature = "with glasses"
149
- elif 'dress' in desc.lower(): key_feature = "in dress"
150
- elif 'hat' in desc.lower(): key_feature = "with hat"
151
-
152
- char_descriptors.append(f"{color} {species} {key_feature}".strip())
153
 
154
- character_essence = f"Characters: {', '.join(char_descriptors)}. "
155
-
156
- # Compress scene description
157
- scene_words = scene_visual.split()
158
- if len(scene_words) > 30:
159
- scene_compressed = ' '.join(scene_words[:30])
160
- else:
161
- scene_compressed = scene_visual
162
 
163
- # Style context
164
  style_context = {
165
- "childrens_book": "children's book illustration",
166
- "realistic": "photorealistic",
167
- "fantasy": "fantasy art",
168
  "anime": "anime style"
169
- }.get(style, "children's book illustration")
170
 
171
- # Build final prompt
172
- continuity = f"Scene {page_number}: " if page_number > 1 else ""
173
- final_prompt = f"{continuity}{scene_compressed}. {character_essence}{style_context}. masterpiece, best quality"
174
 
175
- # Ensure it's under 55 words for safety
176
  words = final_prompt.split()
177
- if len(words) > 55:
178
- final_prompt = ' '.join(words[:55])
 
 
 
 
 
 
 
 
179
 
180
- print(f"πŸ“ Optimized prompt: {final_prompt}")
181
  print(f"πŸ“ Length: {len(final_prompt.split())} words")
182
 
183
  return final_prompt
184
 
185
  def enhance_prompt(scene_visual, characters, style="childrens_book", page_number=1):
186
- """Create optimized prompt"""
187
  main_prompt = optimize_prompt(scene_visual, characters, style, page_number)
188
 
189
  negative_prompt = (
190
  "blurry, low quality, ugly, deformed, bad anatomy, "
191
- "watermark, text, username, multiple people, inconsistent"
 
192
  )
193
 
194
  return main_prompt, negative_prompt
@@ -237,120 +232,146 @@ def get_character_seed(story_title, character_name, page_number):
237
 
238
  return character_seeds[story_title][seed_key]
239
 
240
- def generate_single_page(scene_visual, scene_text, story_title, sequence_number, characters, model_choice, style):
241
- """Generate a single page"""
242
- try:
243
- print(f"πŸ”„ Generating page {sequence_number}...")
244
-
245
- enhanced_prompt, negative_prompt = enhance_prompt(
246
- scene_visual, characters, style, sequence_number
247
- )
248
-
249
- # Get character name for seed
250
- main_char_name = "default"
251
- if characters:
252
- first_char = characters[0]
253
- main_char_name = first_char.get('name', 'default') if isinstance(first_char, dict) else getattr(first_char, 'name', 'default')
254
-
255
- # Use consistent seed
256
- generator = torch.Generator(device="cpu")
257
- main_char_seed = get_character_seed(story_title, main_char_name, sequence_number)
258
- generator.manual_seed(main_char_seed)
259
-
260
- # Generate image
261
- global current_pipe
262
- image = current_pipe(
263
- prompt=enhanced_prompt,
264
- negative_prompt=negative_prompt,
265
- num_inference_steps=20,
266
- guidance_scale=7.0,
267
- width=512,
268
- height=512,
269
- generator=generator
270
- ).images[0]
271
-
272
- # Save to OCI
273
- success, save_status = save_complete_storybook_page(image, story_title, sequence_number, scene_text)
274
-
275
- if success:
276
- print(f"βœ… Page {sequence_number} completed successfully")
277
- return True, save_status
278
- else:
279
- print(f"❌ Page {sequence_number} save failed: {save_status}")
280
- return False, save_status
281
-
282
- except Exception as e:
283
- error_msg = f"❌ Page {sequence_number} generation failed: {str(e)}"
284
- print(error_msg)
285
- return False, error_msg
286
-
287
- # FastAPI endpoint - SYNCHRONOUS VERSION
288
- @app.post("/api/generate-storybook", response_model=StorybookResponse)
289
- async def api_generate_storybook(request: StorybookRequest):
290
- """Synchronous API endpoint that actually works on Hugging Face"""
291
  try:
292
- print(f"πŸ“š Received request: {request.story_title}")
293
- print(f"πŸ“– Pages: {len(request.scenes)}")
294
- print(f"πŸ‘€ Characters: {len(request.characters)}")
295
-
296
- start_time = time.time()
297
 
298
- # Load model ONCE
299
- load_model(request.model_choice)
300
 
301
  # Convert characters to dict
302
  characters_dict = []
303
- for char in request.characters:
304
  characters_dict.append({
305
- "name": char.name,
306
- "description": char.description
307
  })
308
 
309
- generated_count = 0
310
  status_messages = []
 
311
 
312
- # Process each page SEQUENTIALLY
313
- for i, scene in enumerate(request.scenes, 1):
314
  try:
315
- success, message = generate_single_page(
316
- scene.visual,
317
- scene.text,
318
- request.story_title,
319
- i,
320
- characters_dict,
321
- request.model_choice,
322
- request.style
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  )
324
 
325
  if success:
326
- generated_count += 1
327
- status_messages.append(f"Page {i}: {message}")
 
328
  else:
329
- status_messages.append(f"Page {i}: {message}")
 
330
 
331
- # Clean memory after each page
332
  cleanup_memory()
333
 
334
- # Add small delay between pages
335
- if i < len(request.scenes):
336
  time.sleep(1)
337
 
338
  except Exception as e:
339
  error_msg = f"Page {i} failed: {str(e)}"
 
340
  status_messages.append(error_msg)
341
  print(f"❌ {error_msg}")
342
 
343
  total_time = time.time() - start_time
344
 
345
- # Create response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  response_data = {
347
- "status": "success" if generated_count > 0 else "partial",
348
  "story_title": request.story_title,
349
  "total_pages": len(request.scenes),
350
  "characters_used": len(request.characters),
351
- "generated_pages": generated_count,
352
- "generation_time": round(total_time, 2),
353
- "message": "\n".join(status_messages),
354
  "folder_path": f"storybook-library/stories/{request.story_title.replace(' ', '_')}/",
355
  "pages": [
356
  {
@@ -358,19 +379,33 @@ async def api_generate_storybook(request: StorybookRequest):
358
  "image_file": f"page_{i+1:03d}_{request.story_title.replace(' ', '_')}.png",
359
  "text_file": f"page_{i+1:03d}_{request.story_title.replace(' ', '_')}.txt"
360
  } for i in range(len(request.scenes))
361
- ]
 
362
  }
363
 
364
- print(f"βœ… Generation completed in {total_time:.2f} seconds")
365
- print(f"πŸ“Š Generated {generated_count}/{len(request.scenes)} pages")
366
-
367
  return response_data
368
 
369
  except Exception as e:
370
- error_msg = f"Storybook generation failed: {str(e)}"
371
  print(f"❌ {error_msg}")
372
  raise HTTPException(status_code=500, detail=error_msg)
373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  @app.get("/api/health")
375
  async def health_check():
376
  return {
@@ -379,7 +414,8 @@ async def health_check():
379
  "timestamp": datetime.now().isoformat(),
380
  "memory_usage_mb": monitor_memory(),
381
  "models_loaded": list(model_cache.keys()),
382
- "current_model": current_model_name
 
383
  }
384
 
385
  # Simple Gradio interface
@@ -388,21 +424,32 @@ with gr.Blocks(title="Storybook Generator", theme="soft") as demo:
388
 
389
  with gr.Row():
390
  story_title = gr.Textbox(label="Story Title", value="Test Story")
391
- prompt_input = gr.Textbox(label="Scene Description", lines=3, value="A beautiful sunset over mountains")
392
  generate_btn = gr.Button("Generate Test Page")
393
  output_image = gr.Image()
394
  status = gr.Textbox()
395
 
396
  def generate_test_page(prompt, title):
397
  try:
398
- success, message = generate_single_page(
399
- prompt, "", title, 1, [], "dreamshaper-8", "childrens_book"
400
- )
401
- if success:
402
- # For demo, return a placeholder since we can't easily get the image back
403
- return None, message
404
- else:
405
- return None, message
 
 
 
 
 
 
 
 
 
 
 
406
  except Exception as e:
407
  return None, f"Error: {str(e)}"
408
 
@@ -412,7 +459,7 @@ with gr.Blocks(title="Storybook Generator", theme="soft") as demo:
412
  outputs=[output_image, status]
413
  )
414
 
415
- app = gr.mount_gradio_app(app, demo, path="/")
416
 
417
  if __name__ == "__main__":
418
  print("πŸš€ Starting Storybook Generator API...")
 
8
  from datetime import datetime
9
  import re
10
  import time
11
+ import json
12
+ from typing import List, Optional, Dict
13
  from fastapi import FastAPI, HTTPException
14
  from pydantic import BaseModel
15
  import gc
 
58
  message: str
59
  folder_path: str
60
  pages: List[dict]
61
+ request_id: str
62
 
63
  # MODEL SELECTION
64
  MODEL_CHOICES = {
 
74
 
75
  # Character consistency tracking
76
  character_seeds = {}
77
+ active_requests = {}
78
 
79
  def monitor_memory():
80
  try:
 
128
  load_model("dreamshaper-8")
129
  print("βœ… Model loaded and ready!")
130
 
131
+ # PROMPT OPTIMIZATION - PRESERVE FULL DESCRIPTIONS
132
  def optimize_prompt(scene_visual, characters, style="childrens_book", page_number=1):
133
  """
134
+ Create a prompt that PRESERVES all visual descriptions while fitting 77 tokens
135
  """
136
+ # 1. PRESERVE THE ENTIRE SCENE VISUAL DESCRIPTION (most important)
137
+ scene_prompt = scene_visual
138
+
139
+ # 2. Extract only ESSENTIAL character features (not full descriptions)
140
  character_essence = ""
141
  if characters:
142
+ char_names = []
143
  for char in characters:
144
+ char_name = char.get('name', '') if isinstance(char, dict) else getattr(char, 'name', '')
145
+ char_names.append(char_name.split()[0]) # Just first name
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ character_essence = f" featuring {', '.join(char_names)}"
 
 
 
 
 
 
 
148
 
149
+ # 3. Add style context briefly
150
  style_context = {
151
+ "childrens_book": "children's book illustration style",
152
+ "realistic": "photorealistic style",
153
+ "fantasy": "fantasy art style",
154
  "anime": "anime style"
155
+ }.get(style, "children's book illustration style")
156
 
157
+ # 4. Build the final prompt - SCENE DESCRIPTION COMES FIRST
158
+ continuity = f"Scene {page_number}, " if page_number > 1 else ""
159
+ final_prompt = f"{continuity}{scene_prompt}{character_essence}. {style_context}. high quality, detailed"
160
 
161
+ # 5. If still too long, prioritize scene description over style
162
  words = final_prompt.split()
163
+ if len(words) > 60:
164
+ # Keep the scene description intact, trim the end
165
+ scene_words = scene_visual.split()
166
+ if len(scene_words) > 45:
167
+ # If scene itself is too long, keep first 40 words of scene
168
+ scene_part = ' '.join(scene_words[:40])
169
+ final_prompt = f"{continuity}{scene_part}...{character_essence}. {style_context}"
170
+ else:
171
+ # Keep entire scene, trim style part
172
+ final_prompt = f"{continuity}{scene_visual}{character_essence}. high quality"
173
 
174
+ print(f"πŸ“ Final prompt: {final_prompt}")
175
  print(f"πŸ“ Length: {len(final_prompt.split())} words")
176
 
177
  return final_prompt
178
 
179
  def enhance_prompt(scene_visual, characters, style="childrens_book", page_number=1):
180
+ """Create optimized prompt that preserves visual descriptions"""
181
  main_prompt = optimize_prompt(scene_visual, characters, style, page_number)
182
 
183
  negative_prompt = (
184
  "blurry, low quality, ugly, deformed, bad anatomy, "
185
+ "watermark, text, username, multiple people, inconsistent, "
186
+ "missing limbs, extra limbs, disfigured, malformed"
187
  )
188
 
189
  return main_prompt, negative_prompt
 
232
 
233
  return character_seeds[story_title][seed_key]
234
 
235
+ def process_storybook_generation(request_id, request_data):
236
+ """Process generation in background and store results"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  try:
238
+ print(f"πŸ”§ Processing request {request_id} in background...")
 
 
 
 
239
 
240
+ # Load model
241
+ load_model(request_data["model_choice"])
242
 
243
  # Convert characters to dict
244
  characters_dict = []
245
+ for char in request_data["characters"]:
246
  characters_dict.append({
247
+ "name": char["name"],
248
+ "description": char["description"]
249
  })
250
 
251
+ results = []
252
  status_messages = []
253
+ start_time = time.time()
254
 
255
+ # Process each page
256
+ for i, scene in enumerate(request_data["scenes"], 1):
257
  try:
258
+ print(f"πŸ”„ Generating page {i}...")
259
+
260
+ enhanced_prompt, negative_prompt = enhance_prompt(
261
+ scene["visual"], characters_dict, request_data["style"], i
262
+ )
263
+
264
+ # Get character name for seed
265
+ main_char_name = "default"
266
+ if characters_dict:
267
+ main_char_name = characters_dict[0]["name"]
268
+
269
+ # Use consistent seed
270
+ generator = torch.Generator(device="cpu")
271
+ main_char_seed = get_character_seed(request_data["story_title"], main_char_name, i)
272
+ generator.manual_seed(main_char_seed)
273
+
274
+ # Generate image
275
+ global current_pipe
276
+ image = current_pipe(
277
+ prompt=enhanced_prompt,
278
+ negative_prompt=negative_prompt,
279
+ num_inference_steps=25,
280
+ guidance_scale=7.0,
281
+ width=512,
282
+ height=512,
283
+ generator=generator
284
+ ).images[0]
285
+
286
+ # Save to OCI
287
+ success, save_status = save_complete_storybook_page(
288
+ image, request_data["story_title"], i, scene["text"]
289
  )
290
 
291
  if success:
292
+ results.append({"page_number": i, "status": "success"})
293
+ status_messages.append(f"Page {i}: {save_status}")
294
+ print(f"βœ… Page {i} completed")
295
  else:
296
+ results.append({"page_number": i, "status": "error", "message": save_status})
297
+ status_messages.append(f"Page {i}: {save_status}")
298
 
 
299
  cleanup_memory()
300
 
301
+ if i < len(request_data["scenes"]):
 
302
  time.sleep(1)
303
 
304
  except Exception as e:
305
  error_msg = f"Page {i} failed: {str(e)}"
306
+ results.append({"page_number": i, "status": "error", "message": error_msg})
307
  status_messages.append(error_msg)
308
  print(f"❌ {error_msg}")
309
 
310
  total_time = time.time() - start_time
311
 
312
+ # Store results
313
+ active_requests[request_id] = {
314
+ "status": "completed",
315
+ "results": results,
316
+ "message": "\n".join(status_messages),
317
+ "generation_time": total_time,
318
+ "completed_at": datetime.now().isoformat()
319
+ }
320
+
321
+ print(f"βœ… Request {request_id} completed in {total_time:.2f} seconds")
322
+
323
+ except Exception as e:
324
+ active_requests[request_id] = {
325
+ "status": "error",
326
+ "message": f"Processing failed: {str(e)}"
327
+ }
328
+ print(f"❌ Request {request_id} failed: {e}")
329
+
330
+ # FastAPI endpoint - IMMEDIATE RESPONSE
331
+ @app.post("/api/generate-storybook", response_model=StorybookResponse)
332
+ async def api_generate_storybook(request: StorybookRequest):
333
+ """API endpoint that returns immediately"""
334
+ try:
335
+ print(f"πŸ“š Received request: {request.story_title}")
336
+ print(f"πŸ“– Pages: {len(request.scenes)}")
337
+
338
+ # Create request ID
339
+ request_id = f"{request.story_title}_{int(time.time())}"
340
+
341
+ # Convert to dict for background processing
342
+ request_data = {
343
+ "story_title": request.story_title,
344
+ "scenes": [{"visual": scene.visual, "text": scene.text} for scene in request.scenes],
345
+ "characters": [{"name": char.name, "description": char.description} for char in request.characters],
346
+ "model_choice": request.model_choice,
347
+ "style": request.style
348
+ }
349
+
350
+ # Store initial request state
351
+ active_requests[request_id] = {
352
+ "status": "processing",
353
+ "started_at": datetime.now().isoformat(),
354
+ "total_pages": len(request.scenes)
355
+ }
356
+
357
+ # Start background processing in a thread
358
+ import threading
359
+ thread = threading.Thread(
360
+ target=process_storybook_generation,
361
+ args=(request_id, request_data)
362
+ )
363
+ thread.daemon = True
364
+ thread.start()
365
+
366
+ # IMMEDIATE RESPONSE to n8n
367
  response_data = {
368
+ "status": "processing",
369
  "story_title": request.story_title,
370
  "total_pages": len(request.scenes),
371
  "characters_used": len(request.characters),
372
+ "generated_pages": 0,
373
+ "generation_time": 0,
374
+ "message": f"Generation started for {len(request.scenes)} pages. Request ID: {request_id}",
375
  "folder_path": f"storybook-library/stories/{request.story_title.replace(' ', '_')}/",
376
  "pages": [
377
  {
 
379
  "image_file": f"page_{i+1:03d}_{request.story_title.replace(' ', '_')}.png",
380
  "text_file": f"page_{i+1:03d}_{request.story_title.replace(' ', '_')}.txt"
381
  } for i in range(len(request.scenes))
382
+ ],
383
+ "request_id": request_id
384
  }
385
 
 
 
 
386
  return response_data
387
 
388
  except Exception as e:
389
+ error_msg = f"Request failed: {str(e)}"
390
  print(f"❌ {error_msg}")
391
  raise HTTPException(status_code=500, detail=error_msg)
392
 
393
+ # Status check endpoint for n8n
394
+ @app.get("/api/status/{request_id}")
395
+ async def check_status(request_id: str):
396
+ """Check status of a generation request"""
397
+ if request_id not in active_requests:
398
+ return {"status": "not_found", "message": "Request ID not found"}
399
+
400
+ request_data = active_requests[request_id]
401
+ return {
402
+ "status": request_data["status"],
403
+ "message": request_data.get("message", ""),
404
+ "generation_time": request_data.get("generation_time", 0),
405
+ "completed_at": request_data.get("completed_at", ""),
406
+ "total_pages": request_data.get("total_pages", 0)
407
+ }
408
+
409
  @app.get("/api/health")
410
  async def health_check():
411
  return {
 
414
  "timestamp": datetime.now().isoformat(),
415
  "memory_usage_mb": monitor_memory(),
416
  "models_loaded": list(model_cache.keys()),
417
+ "current_model": current_model_name,
418
+ "active_requests": len(active_requests)
419
  }
420
 
421
  # Simple Gradio interface
 
424
 
425
  with gr.Row():
426
  story_title = gr.Textbox(label="Story Title", value="Test Story")
427
+ prompt_input = gr.Textbox(label="Scene Description", lines=3, value="A beautiful sunset over mountains with vibrant colors")
428
  generate_btn = gr.Button("Generate Test Page")
429
  output_image = gr.Image()
430
  status = gr.Textbox()
431
 
432
  def generate_test_page(prompt, title):
433
  try:
434
+ # Test with a simple generation
435
+ enhanced_prompt, negative_prompt = enhance_prompt(prompt, [], "childrens_book", 1)
436
+
437
+ generator = torch.Generator(device="cpu")
438
+ generator.manual_seed(123)
439
+
440
+ global current_pipe
441
+ image = current_pipe(
442
+ prompt=enhanced_prompt,
443
+ negative_prompt=negative_prompt,
444
+ num_inference_steps=20,
445
+ guidance_scale=7.0,
446
+ width=512,
447
+ height=512,
448
+ generator=generator
449
+ ).images[0]
450
+
451
+ return image, f"βœ… Generated: {enhanced_prompt}"
452
+
453
  except Exception as e:
454
  return None, f"Error: {str(e)}"
455
 
 
459
  outputs=[output_image, status]
460
  )
461
 
462
+ app = gr.mount_gradio_app(app, demo, path="/")
463
 
464
  if __name__ == "__main__":
465
  print("πŸš€ Starting Storybook Generator API...")