yukee1992 commited on
Commit
8412e63
·
verified ·
1 Parent(s): 1fe624c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -710
app.py CHANGED
@@ -22,10 +22,24 @@ import random
22
  import time
23
  from requests.adapters import HTTPAdapter
24
  from urllib3.util.retry import Retry
25
- from huggingface_hub import HfApi # NEW: Add this import
 
26
 
27
  # =============================================
28
- # HUGGING FACE DATASET CONFIGURATION (NEW)
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # =============================================
30
  HF_TOKEN = os.environ.get("HF_TOKEN")
31
  HF_USERNAME = "yukee1992"
@@ -77,7 +91,7 @@ class StorybookRequest(BaseModel):
77
  style: str = "childrens_book"
78
  callback_url: Optional[str] = None
79
  consistency_seed: Optional[int] = None
80
- project_id: Optional[str] = None # ADDED for HF Dataset organization
81
 
82
  class JobStatusResponse(BaseModel):
83
  job_id: str
@@ -104,7 +118,7 @@ class MemoryStatusResponse(BaseModel):
104
  gpu_memory_cached_mb: Optional[float] = None
105
  status: str
106
 
107
- # HIGH-QUALITY MODEL SELECTION - ANIME FOCUSED & WORKING
108
  MODEL_CHOICES = {
109
  "dreamshaper-8": "lykon/dreamshaper-8",
110
  "realistic-vision": "SG161222/Realistic_Vision_V5.1",
@@ -132,7 +146,6 @@ def get_memory_usage():
132
  memory_used_mb = memory_info.rss / (1024 * 1024)
133
  memory_percent = process.memory_percent()
134
 
135
- # GPU memory if available
136
  gpu_memory_allocated_mb = None
137
  gpu_memory_cached_mb = None
138
 
@@ -154,17 +167,13 @@ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False,
154
  """Clear memory by unloading models and cleaning up resources"""
155
  results = []
156
 
157
- # Clear model cache
158
  if clear_models:
159
  with model_lock:
160
  models_cleared = len(model_cache)
161
  for model_name, pipe in model_cache.items():
162
  try:
163
- # Move to CPU first if it's on GPU
164
  if hasattr(pipe, 'to'):
165
  pipe.to('cpu')
166
-
167
- # Delete the pipeline
168
  del pipe
169
  results.append(f"Unloaded model: {model_name}")
170
  except Exception as e:
@@ -176,7 +185,6 @@ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False,
176
  current_model_name = None
177
  results.append(f"Cleared {models_cleared} models from cache")
178
 
179
- # Clear completed jobs
180
  if clear_jobs:
181
  jobs_to_clear = []
182
  for job_id, job_data in job_storage.items():
@@ -189,7 +197,6 @@ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False,
189
 
190
  results.append(f"Cleared {len(jobs_to_clear)} completed/failed jobs")
191
 
192
- # Clear local images
193
  if clear_local_images:
194
  try:
195
  storage_info = get_local_storage_info()
@@ -203,7 +210,6 @@ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False,
203
  except Exception as e:
204
  results.append(f"Error clearing local images: {str(e)}")
205
 
206
- # Force garbage collection
207
  if force_gc:
208
  gc.collect()
209
  if torch.cuda.is_available():
@@ -212,7 +218,6 @@ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False,
212
  results.append("GPU cache cleared")
213
  results.append("Garbage collection forced")
214
 
215
- # Get memory status after cleanup
216
  memory_status = get_memory_usage()
217
 
218
  return {
@@ -221,8 +226,11 @@ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False,
221
  "memory_after_cleanup": memory_status
222
  }
223
 
 
 
 
224
  def load_model(model_name="dreamshaper-8"):
225
- """Thread-safe model loading with HIGH-QUALITY settings and better error handling"""
226
  global model_cache, current_model_name, current_pipe
227
 
228
  with model_lock:
@@ -231,53 +239,65 @@ def load_model(model_name="dreamshaper-8"):
231
  current_model_name = model_name
232
  return current_pipe
233
 
234
- print(f"🔄 Loading HIGH-QUALITY model: {model_name}")
235
  try:
236
  model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
237
 
238
- print(f"🔧 Attempting to load: {model_id}")
239
-
240
  pipe = StableDiffusionPipeline.from_pretrained(
241
  model_id,
242
  torch_dtype=torch.float32,
243
  safety_checker=None,
244
  requires_safety_checker=False,
245
- local_files_only=False, # Allow downloading if not cached
246
- cache_dir="./model_cache" # Specific cache directory
 
 
247
  )
248
 
 
249
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
250
  pipe = pipe.to("cpu")
251
 
252
  model_cache[model_name] = pipe
253
  current_pipe = pipe
254
  current_model_name = model_name
255
 
256
- print(f"✅ HIGH-QUALITY Model loaded: {model_name}")
257
  return pipe
258
 
259
  except Exception as e:
260
  print(f"❌ Model loading failed for {model_name}: {e}")
261
  print(f"🔄 Falling back to stable-diffusion-v1-5")
262
 
263
- # Fallback to base model
264
  try:
265
  pipe = StableDiffusionPipeline.from_pretrained(
266
  "runwayml/stable-diffusion-v1-5",
267
  torch_dtype=torch.float32,
268
  safety_checker=None,
269
- requires_safety_checker=False
 
270
  ).to("cpu")
271
 
 
 
272
  model_cache[model_name] = pipe
273
  current_pipe = pipe
274
  current_model_name = "sd-1.5"
275
 
276
- print(f"✅ Fallback model loaded: stable-diffusion-v1-5")
277
  return pipe
278
 
279
  except Exception as fallback_error:
280
- print(f"❌ Critical: Fallback model also failed: {fallback_error}")
281
  raise
282
 
283
  # Initialize default model
@@ -285,11 +305,87 @@ print("🚀 Initializing Storybook Generator API...")
285
  load_model("dreamshaper-8")
286
  print("✅ Model loaded and ready!")
287
 
288
- # SIMPLE PROMPT ENGINEERING - USE PURE PROMPTS ONLY
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  def enhance_prompt_simple(scene_visual, style="childrens_book"):
290
  """Simple prompt enhancement - uses only the provided visual prompt with style"""
291
 
292
- # Style templates
293
  style_templates = {
294
  "childrens_book": "children's book illustration, watercolor style, soft colors, whimsical, magical, storybook art, professional illustration",
295
  "realistic": "photorealistic, detailed, natural lighting, professional photography",
@@ -299,10 +395,8 @@ def enhance_prompt_simple(scene_visual, style="childrens_book"):
299
 
300
  style_prompt = style_templates.get(style, style_templates["childrens_book"])
301
 
302
- # Use only the provided visual prompt with style
303
  enhanced_prompt = f"{style_prompt}, {scene_visual}"
304
 
305
- # Basic negative prompt for quality
306
  negative_prompt = (
307
  "blurry, low quality, bad anatomy, deformed characters, "
308
  "wrong proportions, mismatched features"
@@ -310,13 +404,14 @@ def enhance_prompt_simple(scene_visual, style="childrens_book"):
310
 
311
  return enhanced_prompt, negative_prompt
312
 
 
 
 
313
  def generate_image_simple(prompt, model_choice, style, scene_number, consistency_seed=None):
314
- """Generate image using pure prompts only"""
315
 
316
- # Enhance prompt with simple style addition
317
  enhanced_prompt, negative_prompt = enhance_prompt_simple(prompt, style)
318
 
319
- # Use seed if provided
320
  if consistency_seed:
321
  scene_seed = consistency_seed + scene_number
322
  else:
@@ -325,20 +420,23 @@ def generate_image_simple(prompt, model_choice, style, scene_number, consistency
325
  try:
326
  pipe = load_model(model_choice)
327
 
328
- image = pipe(
329
- prompt=enhanced_prompt,
330
- negative_prompt=negative_prompt,
331
- num_inference_steps=35,
332
- guidance_scale=7.5,
333
- width=768,
334
- height=1024, # Portrait for better full-body
335
- generator=torch.Generator(device="cpu").manual_seed(scene_seed)
336
- ).images[0]
 
 
 
 
 
 
337
 
338
  print(f"✅ Generated image for scene {scene_number}")
339
- print(f"🌱 Seed used: {scene_seed}")
340
- print(f"📝 Pure prompt used: {prompt}")
341
-
342
  return image
343
 
344
  except Exception as e:
@@ -353,12 +451,10 @@ def save_image_to_local(image, prompt, style="test"):
353
  safe_prompt = "".join(c for c in prompt[:50] if c.isalnum() or c in (' ', '-', '_')).rstrip()
354
  filename = f"image_{safe_prompt}_{timestamp}.png"
355
 
356
- # Create style subfolder
357
  style_dir = os.path.join(PERSISTENT_IMAGE_DIR, style)
358
  os.makedirs(style_dir, exist_ok=True)
359
  filepath = os.path.join(style_dir, filename)
360
 
361
- # Save the image
362
  image.save(filepath)
363
  print(f"💾 Image saved locally: {filepath}")
364
 
@@ -373,7 +469,6 @@ def delete_local_image(filepath):
373
  try:
374
  if os.path.exists(filepath):
375
  os.remove(filepath)
376
- print(f"🗑️ Deleted local image: {filepath}")
377
  return True, f"✅ Deleted: {os.path.basename(filepath)}"
378
  else:
379
  return False, f"❌ File not found: {filepath}"
@@ -425,89 +520,6 @@ def refresh_local_images():
425
  print(f"Error refreshing local images: {e}")
426
  return []
427
 
428
- # =============================================
429
- # NEW: HUGGING FACE DATASET FUNCTIONS
430
- # =============================================
431
-
432
- def ensure_dataset_exists():
433
- """Create dataset if it doesn't exist"""
434
- if not HF_TOKEN:
435
- print("⚠️ HF_TOKEN not set, cannot create/verify dataset")
436
- return False
437
-
438
- try:
439
- api = HfApi(token=HF_TOKEN)
440
- try:
441
- api.dataset_info(DATASET_ID)
442
- print(f"✅ Dataset {DATASET_ID} exists")
443
- except Exception:
444
- print(f"📦 Creating dataset: {DATASET_ID}")
445
- api.create_repo(
446
- repo_id=DATASET_ID,
447
- repo_type="dataset",
448
- private=False,
449
- exist_ok=True
450
- )
451
- print(f"✅ Created dataset: {DATASET_ID}")
452
- return True
453
- except Exception as e:
454
- print(f"❌ Failed to ensure dataset: {e}")
455
- return False
456
-
457
- def upload_to_hf_dataset(file_content, filename, subfolder=""):
458
- """Upload a file to Hugging Face Dataset"""
459
- if not HF_TOKEN:
460
- print("⚠️ HF_TOKEN not set, skipping upload")
461
- return None
462
-
463
- try:
464
- if subfolder:
465
- path_in_repo = f"data/{subfolder}/{filename}"
466
- else:
467
- path_in_repo = f"data/{filename}"
468
-
469
- api = HfApi(token=HF_TOKEN)
470
- api.upload_file(
471
- path_or_fileobj=file_content,
472
- path_in_repo=path_in_repo,
473
- repo_id=DATASET_ID,
474
- repo_type="dataset"
475
- )
476
-
477
- url = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{path_in_repo}"
478
- print(f"✅ Uploaded to HF Dataset: {url}")
479
- return url
480
-
481
- except Exception as e:
482
- print(f"❌ Failed to upload to HF Dataset: {e}")
483
- return None
484
-
485
- def upload_image_to_hf_dataset(image, project_id, page_number, prompt, style=""):
486
- """Upload generated image to HF Dataset"""
487
- try:
488
- img_bytes = io.BytesIO()
489
- image.save(img_bytes, format='PNG')
490
- img_data = img_bytes.getvalue()
491
-
492
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
493
- safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
494
- safe_prompt = safe_prompt.replace(' ', '_')
495
- filename = f"page_{page_number:03d}_{safe_prompt}_{timestamp}.png"
496
-
497
- subfolder = f"projects/{project_id}"
498
- url = upload_to_hf_dataset(img_data, filename, subfolder)
499
-
500
- return url
501
-
502
- except Exception as e:
503
- print(f"❌ Failed to upload image to HF Dataset: {e}")
504
- return None
505
-
506
- # =============================================
507
- # REMOVED: OCI BUCKET FUNCTIONS
508
- # (save_to_oci_bucket and test_oci_connection are removed)
509
- # =============================================
510
-
511
  # JOB MANAGEMENT FUNCTIONS
512
  def create_job(story_request: StorybookRequest) -> str:
513
  job_id = str(uuid.uuid4())
@@ -524,8 +536,6 @@ def create_job(story_request: StorybookRequest) -> str:
524
  }
525
 
526
  print(f"📝 Created job {job_id} for story: {story_request.story_title}")
527
- print(f"📄 Scenes to generate: {len(story_request.scenes)}")
528
-
529
  return job_id
530
 
531
  def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
@@ -542,66 +552,28 @@ def update_job_status(job_id: str, status: JobStatus, progress: int, message: st
542
  if result:
543
  job_storage[job_id]["result"] = result
544
 
545
- # Send webhook notification if callback URL exists
546
- job_data = job_storage[job_id]
547
- request_data = job_data["request"]
548
-
549
  if request_data.get("callback_url"):
550
  try:
551
  callback_url = request_data["callback_url"]
552
-
553
  callback_data = {
554
  "job_id": job_id,
555
  "status": status.value,
556
  "progress": progress,
557
  "message": message,
558
  "story_title": request_data["story_title"],
559
- "total_scenes": len(request_data["scenes"]),
560
- "timestamp": time.time(),
561
- "source": "huggingface-image-generator",
562
- "estimated_time_remaining": calculate_remaining_time(job_id, progress)
563
  }
564
 
565
- if status == JobStatus.PROCESSING:
566
- total_scenes = len(request_data["scenes"])
567
- if total_scenes > 0:
568
- current_scene = min((progress - 5) // (90 // total_scenes) + 1, total_scenes)
569
- callback_data["current_scene"] = current_scene
570
- callback_data["total_scenes"] = total_scenes
571
-
572
- if current_scene <= len(request_data["scenes"]):
573
- scene_data = request_data["scenes"][current_scene-1]
574
- callback_data["scene_description"] = scene_data.get("visual", "")[:100] + "..."
575
- callback_data["current_prompt"] = scene_data.get("visual", "")
576
-
577
  if status == JobStatus.COMPLETED and result:
578
  callback_data["result"] = {
579
- "total_pages": result.get("total_pages", 0),
580
- "generation_time": result.get("generation_time", 0),
581
- "hf_dataset_url": result.get("hf_dataset_url", ""),
582
- "pages_generated": result.get("generated_pages", 0),
583
- "consistency_seed": result.get("consistency_seed", None),
584
- "image_urls": result.get("image_urls", [])
585
  }
586
 
587
- headers = {
588
- 'Content-Type': 'application/json',
589
- 'User-Agent': 'Storybook-Generator/1.0'
590
- }
591
-
592
- print(f"📢 Sending callback to: {callback_url}")
593
-
594
- response = requests.post(
595
- callback_url,
596
- json=callback_data,
597
- headers=headers,
598
- timeout=30
599
- )
600
-
601
- print(f"📢 Callback sent: Status {response.status_code}")
602
-
603
  except Exception as e:
604
- print(f"⚠️ Callback failed: {str(e)}")
605
 
606
  return True
607
 
@@ -622,28 +594,21 @@ def calculate_remaining_time(job_id, progress):
622
 
623
  return "Unknown"
624
 
625
- # UPDATED BACKGROUND TASK - Uses HF Dataset instead of OCI
626
  def generate_storybook_background(job_id: str):
627
- """Background task to generate complete storybook and upload to HF Dataset"""
628
  try:
629
- # Ensure HF Dataset exists
630
  if HF_TOKEN:
631
  ensure_dataset_exists()
632
 
633
  job_data = job_storage[job_id]
634
- story_request_data = job_data["request"]
635
- story_request = StorybookRequest(**story_request_data)
636
 
637
- # Use project_id from request or generate from story title
638
  project_id = story_request.project_id or story_request.story_title.replace(' ', '_').lower()
639
 
640
  print(f"🎬 Starting storybook generation for job {job_id}")
641
- print(f"📖 Story: {story_request.story_title}")
642
- print(f"📄 Scenes: {len(story_request.scenes)}")
643
- print(f"🎨 Style: {story_request.style}")
644
- print(f"📦 Project ID: {project_id}")
645
 
646
- update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting storybook generation with pure prompts...")
647
 
648
  total_scenes = len(story_request.scenes)
649
  generated_pages = []
@@ -657,14 +622,11 @@ def generate_storybook_background(job_id: str):
657
  job_id,
658
  JobStatus.PROCESSING,
659
  progress,
660
- f"Generating page {i+1}/{total_scenes}: {scene.visual[:50]}..."
661
  )
662
 
663
  try:
664
- print(f"🖼️ Generating page {i+1}")
665
- print(f"📝 Pure prompt: {scene.visual}")
666
-
667
- # Generate image using pure prompt only
668
  image = generate_image_simple(
669
  scene.visual,
670
  story_request.model_choice,
@@ -673,9 +635,8 @@ def generate_storybook_background(job_id: str):
673
  story_request.consistency_seed
674
  )
675
 
676
- # Save locally as backup
677
  local_filepath, local_filename = save_image_to_local(image, scene.visual, story_request.style)
678
- print(f"💾 Image saved locally as backup: {local_filename}")
679
 
680
  # Upload to HF Dataset
681
  hf_url = None
@@ -690,161 +651,108 @@ def generate_storybook_background(job_id: str):
690
 
691
  if hf_url:
692
  image_urls.append(hf_url)
693
- print(f"✅ Uploaded to HF Dataset: {hf_url}")
694
 
695
- # Store page data
696
  page_data = {
697
  "page_number": i + 1,
698
  "image_url": hf_url or f"local://{local_filepath}",
699
- "hf_dataset_url": hf_url,
700
  "text_content": scene.text,
701
- "visual_description": scene.visual,
702
- "prompt_used": scene.visual,
703
- "local_backup_path": local_filepath
704
  }
705
  generated_pages.append(page_data)
706
 
707
- print(f"✅ Page {i+1} completed")
 
 
 
708
 
709
  except Exception as e:
710
- error_msg = f"Failed to generate page {i+1}: {str(e)}"
711
- print(f"❌ {error_msg}")
712
- update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
713
  return
714
 
715
- # Complete the job
716
  generation_time = time.time() - start_time
717
 
718
- # Count successful HF uploads
719
- hf_success_count = len(image_urls)
720
- local_fallback_count = total_scenes - hf_success_count
721
-
722
  result = {
723
  "story_title": story_request.story_title,
724
  "project_id": project_id,
725
  "total_pages": total_scenes,
726
- "generated_pages": len(generated_pages),
727
  "generation_time": round(generation_time, 2),
728
  "hf_dataset_url": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
729
- "consistency_seed": story_request.consistency_seed,
730
- "pages": generated_pages,
731
  "image_urls": image_urls,
732
- "upload_summary": {
733
- "hf_successful": hf_success_count,
734
- "local_fallback": local_fallback_count,
735
- "total_attempted": total_scenes
736
- }
737
  }
738
 
739
- status_message = f"🎉 Storybook completed! {len(generated_pages)} pages created in {generation_time:.2f}s."
740
- if hf_success_count > 0:
741
- status_message += f" {hf_success_count} images uploaded to HF Dataset."
742
- if local_fallback_count > 0:
743
- status_message += f" {local_fallback_count} pages saved locally."
744
-
745
  update_job_status(
746
  job_id,
747
  JobStatus.COMPLETED,
748
  100,
749
- status_message,
750
  result
751
  )
752
 
753
- print(f"🎉 Storybook generation finished for job {job_id}")
754
- print(f"📤 HF Uploads: {hf_success_count} successful, {local_fallback_count} local fallbacks")
755
-
756
  except Exception as e:
757
- error_msg = f"Story generation failed: {str(e)}"
758
  print(f"❌ {error_msg}")
759
  update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
760
 
761
- # FASTAPI ENDPOINTS (for n8n)
762
  @app.post("/api/generate-storybook")
763
  async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
764
- """Main endpoint for n8n integration - generates complete storybook using pure prompts"""
765
  try:
766
- print(f"📥 Received n8n request for story: {request.get('story_title', 'Unknown')}")
767
 
768
- # Add consistency seed if not provided
769
- if 'consistency_seed' not in request or not request['consistency_seed']:
770
  request['consistency_seed'] = random.randint(1000, 9999)
771
- print(f"🌱 Generated consistency seed: {request['consistency_seed']}")
772
 
773
- # Generate project_id if not provided
774
  if 'project_id' not in request:
775
  request['project_id'] = request.get('story_title', 'unknown').replace(' ', '_').lower()
776
 
777
- # Convert to Pydantic model
778
  story_request = StorybookRequest(**request)
779
 
780
- # Validate required fields
781
  if not story_request.story_title or not story_request.scenes:
782
- raise HTTPException(status_code=400, detail="story_title and scenes are required")
783
 
784
- # Create job immediately
785
  job_id = create_job(story_request)
786
-
787
- # Start background processing
788
  background_tasks.add_task(generate_storybook_background, job_id)
789
 
790
- # Immediate response for n8n
791
- response_data = {
792
  "status": "success",
793
- "message": "Storybook generation started",
794
  "job_id": job_id,
795
  "story_title": story_request.story_title,
796
  "project_id": request['project_id'],
797
  "total_scenes": len(story_request.scenes),
798
- "consistency_seed": story_request.consistency_seed,
799
  "hf_dataset": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
800
- "callback_url": story_request.callback_url,
801
- "estimated_time_seconds": len(story_request.scenes) * 35,
802
- "timestamp": datetime.now().isoformat()
803
  }
804
 
805
- print(f"✅ Job {job_id} started for: {story_request.story_title}")
806
-
807
- return response_data
808
-
809
  except Exception as e:
810
- error_msg = f"API Error: {str(e)}"
811
- print(f"❌ {error_msg}")
812
- raise HTTPException(status_code=500, detail=error_msg)
813
 
814
  @app.get("/api/job-status/{job_id}")
815
- async def get_job_status_endpoint(job_id: str):
816
- """Check job status"""
817
  job_data = job_storage.get(job_id)
818
  if not job_data:
819
  raise HTTPException(status_code=404, detail="Job not found")
820
 
821
- return JobStatusResponse(
822
- job_id=job_id,
823
- status=job_data["status"],
824
- progress=job_data["progress"],
825
- message=job_data["message"],
826
- result=job_data["result"],
827
- created_at=job_data["created_at"],
828
- updated_at=job_data["updated_at"]
829
- )
830
 
831
  @app.get("/api/health")
832
- async def api_health():
833
- """Health check endpoint for n8n"""
834
  return {
835
  "status": "healthy",
836
  "service": "storybook-generator",
837
  "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
838
- "hf_token_set": bool(HF_TOKEN),
839
- "timestamp": datetime.now().isoformat(),
840
- "active_jobs": len(job_storage),
841
- "models_loaded": list(model_cache.keys())
842
  }
843
 
844
- # NEW: Endpoint to get project images from HF Dataset
845
  @app.get("/api/project-images/{project_id}")
846
  async def get_project_images(project_id: str):
847
- """Get all images for a project from HF Dataset"""
848
  try:
849
  if not HF_TOKEN:
850
  return {"error": "HF_TOKEN not set"}
@@ -853,470 +761,65 @@ async def get_project_images(project_id: str):
853
  files = api.list_repo_files(repo_id=DATASET_ID, repo_type="dataset")
854
 
855
  project_files = [f for f in files if f.startswith(f"data/projects/{project_id}/")]
856
-
857
  urls = [f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{f}" for f in project_files]
858
 
859
- return {
860
- "project_id": project_id,
861
- "total_images": len(urls),
862
- "image_urls": urls
863
- }
864
  except Exception as e:
865
  return {"error": str(e)}
866
 
867
- # NEW MEMORY MANAGEMENT ENDPOINTS
868
- @app.get("/api/memory-status")
869
- async def get_memory_status():
870
- """Get current memory usage and system status"""
871
- memory_info = get_memory_usage()
872
- return MemoryStatusResponse(
873
- memory_used_mb=memory_info["memory_used_mb"],
874
- memory_percent=memory_info["memory_percent"],
875
- models_loaded=memory_info["models_loaded"],
876
- active_jobs=memory_info["active_jobs"],
877
- local_images_count=memory_info["local_images_count"],
878
- gpu_memory_allocated_mb=memory_info["gpu_memory_allocated_mb"],
879
- gpu_memory_cached_mb=memory_info["gpu_memory_cached_mb"],
880
- status="healthy"
881
- )
882
-
883
- @app.post("/api/clear-memory")
884
- async def clear_memory_endpoint(request: MemoryClearanceRequest):
885
- """Clear memory by unloading models and cleaning up resources"""
886
- try:
887
- result = clear_memory(
888
- clear_models=request.clear_models,
889
- clear_jobs=request.clear_jobs,
890
- clear_local_images=request.clear_local_images,
891
- force_gc=request.force_gc
892
- )
893
-
894
- return {
895
- "status": "success",
896
- "message": "Memory clearance completed",
897
- "details": result
898
- }
899
-
900
- except Exception as e:
901
- raise HTTPException(status_code=500, detail=f"Memory clearance failed: {str(e)}")
902
-
903
- @app.post("/api/auto-cleanup")
904
- async def auto_cleanup():
905
- """Automatic cleanup - clears completed jobs and forces GC"""
906
- try:
907
- result = clear_memory(
908
- clear_models=False, # Don't clear models by default
909
- clear_jobs=True, # Clear completed jobs
910
- clear_local_images=False, # Don't clear images by default
911
- force_gc=True # Force garbage collection
912
- )
913
-
914
- return {
915
- "status": "success",
916
- "message": "Automatic cleanup completed",
917
- "details": result
918
- }
919
-
920
- except Exception as e:
921
- raise HTTPException(status_code=500, detail=f"Auto cleanup failed: {str(e)}")
922
-
923
- @app.get("/api/local-images")
924
- async def get_local_images():
925
- """API endpoint to get locally saved test images"""
926
- storage_info = get_local_storage_info()
927
- return storage_info
928
-
929
- @app.delete("/api/local-images/{filename:path}")
930
- async def delete_local_image_api(filename: str):
931
- """API endpoint to delete a local image"""
932
- try:
933
- filepath = os.path.join(PERSISTENT_IMAGE_DIR, filename)
934
- success, message = delete_local_image(filepath)
935
- return {"status": "success" if success else "error", "message": message}
936
- except Exception as e:
937
- return {"status": "error", "message": str(e)}
938
-
939
- # SIMPLE GRADIO INTERFACE
940
  def create_gradio_interface():
941
- """Create simple Gradio interface for testing"""
942
-
943
- def generate_test_image_simple(prompt, model_choice, style_choice):
944
- """Generate a single image using pure prompt only"""
945
- try:
946
- if not prompt.strip():
947
- return None, "❌ Please enter a prompt", None
948
-
949
- print(f"🎨 Generating test image with pure prompt: {prompt}")
950
-
951
- # Generate the image using pure prompt
952
- image = generate_image_simple(
953
- prompt,
954
- model_choice,
955
- style_choice,
956
- 1
957
- )
958
-
959
- # Save to local storage
960
- filepath, filename = save_image_to_local(image, prompt, style_choice)
961
-
962
- status_msg = f"""✅ Success! Generated: {prompt}
963
-
964
- 📁 **Local file:** {filename if filename else 'Not saved'}"""
965
-
966
- return image, status_msg, filepath
967
-
968
- except Exception as e:
969
- error_msg = f"❌ Generation failed: {str(e)}"
970
- print(error_msg)
971
- return None, error_msg, None
972
-
973
- with gr.Blocks(title="Simple Image Generator", theme="soft") as demo:
974
- gr.Markdown("# 🎨 Simple Image Generator")
975
- gr.Markdown("Generate images using **pure prompts only** - no automatic enhancements")
976
-
977
- # Storage info display
978
- storage_info = gr.Textbox(
979
- label="📊 Local Storage Information",
980
- interactive=False,
981
- lines=2
982
- )
983
-
984
- # Memory status display
985
- memory_status = gr.Textbox(
986
- label="🧠 Memory Status",
987
- interactive=False,
988
- lines=3
989
- )
990
-
991
- # HF Dataset status
992
- hf_status = gr.Textbox(
993
- label="📤 Hugging Face Dataset",
994
- value=f"✅ Connected to {DATASET_ID}" if HF_TOKEN else "❌ HF_TOKEN not set - local only",
995
- interactive=False,
996
- lines=2
997
- )
998
 
999
- def update_storage_info():
1000
- info = get_local_storage_info()
1001
- if "error" not in info:
1002
- return f"📁 Local Storage: {info['total_files']} images, {info['total_size_mb']} MB used"
1003
- return "📁 Local Storage: Unable to calculate"
1004
 
1005
- def update_memory_status():
1006
- memory_info = get_memory_usage()
1007
- status_text = f"🧠 Memory Usage: {memory_info['memory_used_mb']} MB ({memory_info['memory_percent']}%)\n"
1008
- status_text += f"📦 Models Loaded: {memory_info['models_loaded']}\n"
1009
- status_text += f"⚡ Active Jobs: {memory_info['active_jobs']}"
1010
-
1011
- if memory_info['gpu_memory_allocated_mb']:
1012
- status_text += f"\n🎮 GPU Memory: {memory_info['gpu_memory_allocated_mb']} MB allocated"
1013
-
1014
- return status_text
1015
 
1016
  with gr.Row():
1017
- with gr.Column(scale=1):
1018
- gr.Markdown("### 🎯 Quality Settings")
1019
-
1020
- model_dropdown = gr.Dropdown(
1021
- label="AI Model",
1022
- choices=list(MODEL_CHOICES.keys()),
1023
- value="dreamshaper-8"
1024
- )
1025
-
1026
- style_dropdown = gr.Dropdown(
1027
- label="Art Style",
1028
- choices=["childrens_book", "realistic", "fantasy", "anime"],
1029
- value="anime"
1030
- )
1031
-
1032
- prompt_input = gr.Textbox(
1033
- label="Pure Prompt",
1034
- placeholder="Enter your exact prompt...",
1035
- lines=3
1036
- )
1037
-
1038
- generate_btn = gr.Button("✨ Generate Image", variant="primary")
1039
-
1040
- # Current image management
1041
- current_file_path = gr.State()
1042
- delete_btn = gr.Button("🗑️ Delete This Image", variant="stop")
1043
- delete_status = gr.Textbox(label="Delete Status", interactive=False, lines=2)
1044
-
1045
- # Memory management section
1046
- gr.Markdown("### 🧠 Memory Management")
1047
- with gr.Row():
1048
- auto_cleanup_btn = gr.Button("🔄 Auto Cleanup", size="sm")
1049
- clear_models_btn = gr.Button("🗑️ Clear Models", variant="stop", size="sm")
1050
-
1051
- memory_clear_status = gr.Textbox(label="Memory Clear Status", interactive=False, lines=2)
1052
-
1053
- gr.Markdown("### 📚 API Usage for n8n")
1054
- gr.Markdown(f"""
1055
- **Generate Storybook:**
1056
- - Endpoint: `POST /api/generate-storybook`
1057
- - Body: `{{"story_title": "...", "scenes": [...]}}`
1058
-
1059
- **Check Status:**
1060
- - `GET /api/job-status/{{job_id}}`
1061
-
1062
- **HF Dataset:**
1063
- - `{DATASET_ID if HF_TOKEN else "Set HF_TOKEN to enable"}`
1064
- """)
1065
-
1066
- with gr.Column(scale=2):
1067
- image_output = gr.Image(label="Generated Image", height=500, show_download_button=True)
1068
- status_output = gr.Textbox(label="Status", interactive=False, lines=4)
1069
-
1070
- # Local file management section
1071
- with gr.Accordion("📁 Manage Local Test Images", open=True):
1072
- gr.Markdown("### Locally Saved Images")
1073
-
1074
- with gr.Row():
1075
- refresh_btn = gr.Button("🔄 Refresh List")
1076
- clear_all_btn = gr.Button("🗑️ Clear All Images", variant="stop")
1077
 
1078
- file_gallery = gr.Gallery(
1079
- label="Local Images",
1080
- show_label=True,
1081
- elem_id="gallery",
1082
- columns=4,
1083
- height="auto"
1084
- )
1085
-
1086
- clear_status = gr.Textbox(label="Clear Status", interactive=False)
1087
-
1088
- def delete_current_image(filepath):
1089
- """Delete the currently displayed image"""
1090
- if not filepath:
1091
- return "❌ No image to delete", None, None, refresh_local_images()
1092
-
1093
- success, message = delete_local_image(filepath)
1094
- updated_files = refresh_local_images()
1095
-
1096
- if success:
1097
- status_msg = f"✅ {message}"
1098
- return status_msg, None, "Image deleted successfully!", updated_files
1099
- else:
1100
- return f"❌ {message}", None, "Delete failed", updated_files
1101
-
1102
- def clear_all_images():
1103
- """Delete all local images"""
1104
- try:
1105
- storage_info = get_local_storage_info()
1106
- deleted_count = 0
1107
-
1108
- if "images" in storage_info:
1109
- for image_info in storage_info["images"]:
1110
- success, _ = delete_local_image(image_info["path"])
1111
- if success:
1112
- deleted_count += 1
1113
-
1114
- updated_files = refresh_local_images()
1115
- return f"✅ Deleted {deleted_count} images", updated_files
1116
- except Exception as e:
1117
- return f"❌ Error: {str(e)}", refresh_local_images()
1118
-
1119
- def perform_auto_cleanup():
1120
- """Perform automatic cleanup"""
1121
- try:
1122
- result = clear_memory(
1123
- clear_models=False,
1124
- clear_jobs=True,
1125
- clear_local_images=False,
1126
- force_gc=True
1127
- )
1128
- return f"✅ Auto cleanup completed: {len(result['actions_performed'])} actions"
1129
- except Exception as e:
1130
- return f"❌ Auto cleanup failed: {str(e)}"
1131
-
1132
- def clear_models():
1133
- """Clear all loaded models"""
1134
- try:
1135
- result = clear_memory(
1136
- clear_models=True,
1137
- clear_jobs=False,
1138
- clear_local_images=False,
1139
- force_gc=True
1140
- )
1141
- return f"✅ Models cleared: {len(result['actions_performed'])} actions"
1142
- except Exception as e:
1143
- return f"❌ Model clearance failed: {str(e)}"
1144
-
1145
- # Connect buttons to functions
1146
- generate_btn.click(
1147
- fn=generate_test_image_simple,
1148
- inputs=[prompt_input, model_dropdown, style_dropdown],
1149
- outputs=[image_output, status_output, current_file_path]
1150
- ).then(
1151
- fn=refresh_local_images,
1152
- outputs=file_gallery
1153
- ).then(
1154
- fn=update_storage_info,
1155
- outputs=storage_info
1156
- ).then(
1157
- fn=update_memory_status,
1158
- outputs=memory_status
1159
- )
1160
-
1161
- delete_btn.click(
1162
- fn=delete_current_image,
1163
- inputs=current_file_path,
1164
- outputs=[delete_status, image_output, status_output, file_gallery]
1165
- ).then(
1166
- fn=update_storage_info,
1167
- outputs=storage_info
1168
- ).then(
1169
- fn=update_memory_status,
1170
- outputs=memory_status
1171
- )
1172
-
1173
- refresh_btn.click(
1174
- fn=refresh_local_images,
1175
- outputs=file_gallery
1176
- ).then(
1177
- fn=update_storage_info,
1178
- outputs=storage_info
1179
- ).then(
1180
- fn=update_memory_status,
1181
- outputs=memory_status
1182
- )
1183
 
1184
- clear_all_btn.click(
1185
- fn=clear_all_images,
1186
- outputs=[clear_status, file_gallery]
1187
- ).then(
1188
- fn=update_storage_info,
1189
- outputs=storage_info
1190
- ).then(
1191
- fn=update_memory_status,
1192
- outputs=memory_status
1193
- )
1194
-
1195
- # Memory management buttons
1196
- auto_cleanup_btn.click(
1197
- fn=perform_auto_cleanup,
1198
- outputs=memory_clear_status
1199
- ).then(
1200
- fn=update_memory_status,
1201
- outputs=memory_status
1202
- )
1203
-
1204
- clear_models_btn.click(
1205
- fn=clear_models,
1206
- outputs=memory_clear_status
1207
- ).then(
1208
- fn=update_memory_status,
1209
- outputs=memory_status
1210
- )
1211
-
1212
- # Initialize on load
1213
- demo.load(fn=refresh_local_images, outputs=file_gallery)
1214
- demo.load(fn=update_storage_info, outputs=storage_info)
1215
- demo.load(fn=update_memory_status, outputs=memory_status)
1216
 
1217
  return demo
1218
 
1219
- # Create simple Gradio app
1220
  demo = create_gradio_interface()
1221
 
1222
- # Simple root endpoint
1223
  @app.get("/")
1224
  async def root():
1225
  return {
1226
- "message": "Storybook Generator API with HF Dataset is running!",
1227
- "api_endpoints": {
1228
- "health_check": "GET /api/health",
1229
- "generate_storybook": "POST /api/generate-storybook",
1230
- "check_job_status": "GET /api/job-status/{job_id}",
1231
- "project_images": "GET /api/project-images/{project_id}",
1232
- "local_images": "GET /api/local-images",
1233
- "memory_status": "GET /api/memory-status",
1234
- "clear_memory": "POST /api/clear-memory",
1235
- "auto_cleanup": "POST /api/auto-cleanup"
1236
- },
1237
  "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
1238
- "features": {
1239
- "pure_prompts": " Enabled - No automatic enhancements",
1240
- "n8n_integration": " Enabled",
1241
- "memory_management": " Enabled",
1242
- "hf_dataset": " Enabled" if HF_TOKEN else "❌ Disabled"
1243
  },
1244
- "web_interface": "GET /ui"
1245
- }
1246
-
1247
- # Add a simple test endpoint
1248
- @app.get("/api/test")
1249
- async def test_endpoint():
1250
- return {
1251
- "status": "success",
1252
- "message": "API is working correctly",
1253
- "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
1254
- "timestamp": datetime.now().isoformat()
1255
  }
1256
 
1257
- # For Hugging Face Spaces deployment
1258
- def get_app():
1259
- return app
1260
-
1261
  if __name__ == "__main__":
1262
  import uvicorn
1263
- import os
1264
-
1265
- # Check if we're running on Hugging Face Spaces
1266
- HF_SPACE = os.environ.get('SPACE_ID') is not None
1267
 
1268
- if HF_SPACE:
1269
- print("🚀 Running on Hugging Face Spaces - Integrated Mode")
1270
  print(f"📦 HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
1271
- print("📚 API endpoints available at: /api/*")
1272
- print("🎨 Web interface available at: /ui")
1273
- print("📝 PURE PROMPTS enabled - no automatic enhancements")
1274
- print("🧠 MEMORY MANAGEMENT enabled - automatic cleanup available")
1275
-
1276
- # Mount Gradio without reassigning app
1277
  gr.mount_gradio_app(app, demo, path="/ui")
1278
-
1279
- # Run the combined app
1280
- uvicorn.run(
1281
- app,
1282
- host="0.0.0.0",
1283
- port=7860,
1284
- log_level="info"
1285
- )
1286
  else:
1287
- # Local development - run separate servers
1288
- print("🚀 Running locally - Separate API and UI servers")
1289
- print(f"📦 HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
1290
- print("📚 API endpoints: http://localhost:8000/api/*")
1291
- print("🎨 Web interface: http://localhost:7860/ui")
1292
- print("📝 PURE PROMPTS enabled - no automatic enhancements")
1293
- print("🧠 MEMORY MANAGEMENT enabled - automatic cleanup available")
1294
-
1295
- def run_fastapi():
1296
- """Run FastAPI on port 8000 for API calls"""
1297
- uvicorn.run(
1298
- app,
1299
- host="0.0.0.0",
1300
- port=8000,
1301
- log_level="info",
1302
- access_log=False
1303
- )
1304
-
1305
- def run_gradio():
1306
- """Run Gradio on port 7860 for web interface"""
1307
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
1308
-
1309
- # Run both servers in separate threads
1310
- import threading
1311
- fastapi_thread = threading.Thread(target=run_fastapi, daemon=True)
1312
- gradio_thread = threading.Thread(target=run_gradio, daemon=True)
1313
-
1314
- fastapi_thread.start()
1315
- gradio_thread.start()
1316
-
1317
- try:
1318
- # Keep main thread alive
1319
- while True:
1320
- time.sleep(1)
1321
- except KeyboardInterrupt:
1322
- print("🛑 Shutting down servers...")
 
22
  import time
23
  from requests.adapters import HTTPAdapter
24
  from urllib3.util.retry import Retry
25
+ from huggingface_hub import HfApi
26
+ import accelerate # Add this for better memory management
27
 
28
  # =============================================
29
# =============================================
# MEMORY OPTIMIZATION SETTINGS
# =============================================
# Enable memory-efficient attention kernels when PyTorch was built with CUDA.
# BUG FIX: torch.backends.cuda has no is_enabled() attribute — the original
# call raised AttributeError at import time. is_built() is the correct check
# for CUDA support being compiled into this PyTorch build.
if hasattr(torch, "backends") and hasattr(torch.backends, "cuda") and torch.backends.cuda.is_built():
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)

# Environment knobs that reduce CUDA allocator fragmentation and cap
# CPU thread contention for BLAS/OpenMP workloads.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
41
+ # =============================================
42
+ # HUGGING FACE DATASET CONFIGURATION
43
  # =============================================
44
  HF_TOKEN = os.environ.get("HF_TOKEN")
45
  HF_USERNAME = "yukee1992"
 
91
  style: str = "childrens_book"
92
  callback_url: Optional[str] = None
93
  consistency_seed: Optional[int] = None
94
+ project_id: Optional[str] = None
95
 
96
  class JobStatusResponse(BaseModel):
97
  job_id: str
 
118
  gpu_memory_cached_mb: Optional[float] = None
119
  status: str
120
 
121
+ # HIGH-QUALITY MODEL SELECTION
122
  MODEL_CHOICES = {
123
  "dreamshaper-8": "lykon/dreamshaper-8",
124
  "realistic-vision": "SG161222/Realistic_Vision_V5.1",
 
146
  memory_used_mb = memory_info.rss / (1024 * 1024)
147
  memory_percent = process.memory_percent()
148
 
 
149
  gpu_memory_allocated_mb = None
150
  gpu_memory_cached_mb = None
151
 
 
167
  """Clear memory by unloading models and cleaning up resources"""
168
  results = []
169
 
 
170
  if clear_models:
171
  with model_lock:
172
  models_cleared = len(model_cache)
173
  for model_name, pipe in model_cache.items():
174
  try:
 
175
  if hasattr(pipe, 'to'):
176
  pipe.to('cpu')
 
 
177
  del pipe
178
  results.append(f"Unloaded model: {model_name}")
179
  except Exception as e:
 
185
  current_model_name = None
186
  results.append(f"Cleared {models_cleared} models from cache")
187
 
 
188
  if clear_jobs:
189
  jobs_to_clear = []
190
  for job_id, job_data in job_storage.items():
 
197
 
198
  results.append(f"Cleared {len(jobs_to_clear)} completed/failed jobs")
199
 
 
200
  if clear_local_images:
201
  try:
202
  storage_info = get_local_storage_info()
 
210
  except Exception as e:
211
  results.append(f"Error clearing local images: {str(e)}")
212
 
 
213
  if force_gc:
214
  gc.collect()
215
  if torch.cuda.is_available():
 
218
  results.append("GPU cache cleared")
219
  results.append("Garbage collection forced")
220
 
 
221
  memory_status = get_memory_usage()
222
 
223
  return {
 
226
  "memory_after_cleanup": memory_status
227
  }
228
 
229
+ # =============================================
230
+ # OPTIMIZED MODEL LOADING - MAINTAINS QUALITY
231
+ # =============================================
232
  def load_model(model_name="dreamshaper-8"):
233
+ """Thread-safe model loading with memory optimization but maintaining quality"""
234
  global model_cache, current_model_name, current_pipe
235
 
236
  with model_lock:
 
239
  current_model_name = model_name
240
  return current_pipe
241
 
242
+ print(f"🔄 Loading model: {model_name}")
243
  try:
244
  model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
245
 
246
+ # Load with memory optimizations but keep quality
 
247
  pipe = StableDiffusionPipeline.from_pretrained(
248
  model_id,
249
  torch_dtype=torch.float32,
250
  safety_checker=None,
251
  requires_safety_checker=False,
252
+ cache_dir="./model_cache",
253
+ low_cpu_mem_usage=True, # Reduces memory during loading
254
+ use_safetensors=True,
255
+ variant="fp32" # Use full precision for quality
256
  )
257
 
258
+ # Use memory efficient scheduler (maintains quality)
259
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
260
+
261
+ # Enable attention slicing (trades speed for memory)
262
+ pipe.enable_attention_slicing()
263
+
264
+ # Enable sequential CPU offload if needed
265
+ if not torch.cuda.is_available():
266
+ pipe.enable_sequential_cpu_offload()
267
+
268
  pipe = pipe.to("cpu")
269
 
270
  model_cache[model_name] = pipe
271
  current_pipe = pipe
272
  current_model_name = model_name
273
 
274
+ print(f"✅ Model loaded: {model_name}")
275
  return pipe
276
 
277
  except Exception as e:
278
  print(f"❌ Model loading failed for {model_name}: {e}")
279
  print(f"🔄 Falling back to stable-diffusion-v1-5")
280
 
 
281
  try:
282
  pipe = StableDiffusionPipeline.from_pretrained(
283
  "runwayml/stable-diffusion-v1-5",
284
  torch_dtype=torch.float32,
285
  safety_checker=None,
286
+ requires_safety_checker=False,
287
+ low_cpu_mem_usage=True
288
  ).to("cpu")
289
 
290
+ pipe.enable_attention_slicing()
291
+
292
  model_cache[model_name] = pipe
293
  current_pipe = pipe
294
  current_model_name = "sd-1.5"
295
 
296
+ print(f"✅ Fallback model loaded")
297
  return pipe
298
 
299
  except Exception as fallback_error:
300
+ print(f"❌ Fallback model failed: {fallback_error}")
301
  raise
302
 
303
  # Initialize default model
 
305
  load_model("dreamshaper-8")
306
  print("✅ Model loaded and ready!")
307
 
308
+ # =============================================
309
+ # HF DATASET FUNCTIONS
310
+ # =============================================
311
def ensure_dataset_exists():
    """Verify that the HF dataset repo exists, creating it when missing.

    Returns:
        True when the dataset is reachable or was just created,
        False when HF_TOKEN is absent or the Hub calls fail.
    """
    if not HF_TOKEN:
        print("⚠️ HF_TOKEN not set, cannot create/verify dataset")
        return False

    try:
        hub = HfApi(token=HF_TOKEN)
        try:
            # Probe for the repo; any failure is treated as "does not exist".
            hub.dataset_info(DATASET_ID)
        except Exception:
            print(f"📦 Creating dataset: {DATASET_ID}")
            hub.create_repo(
                repo_id=DATASET_ID,
                repo_type="dataset",
                private=False,
                exist_ok=True,
            )
            print(f"✅ Created dataset: {DATASET_ID}")
        else:
            print(f"✅ Dataset {DATASET_ID} exists")
        return True
    except Exception as e:
        print(f"❌ Failed to ensure dataset: {e}")
        return False
335
+
336
def upload_to_hf_dataset(file_content, filename, subfolder=""):
    """Upload one file into the HF dataset repo under data/.

    Args:
        file_content: Raw bytes (or file-like object) to upload.
        filename: Name the file should have inside the repo.
        subfolder: Optional sub-path beneath data/ (e.g. "projects/<id>").

    Returns:
        The public resolve URL of the uploaded file, or None when HF_TOKEN
        is missing or the upload fails (errors are printed, not raised).
    """
    if not HF_TOKEN:
        print("⚠️ HF_TOKEN not set, skipping upload")
        return None

    try:
        # BUG FIX: the repo path previously interpolated a placeholder
        # literal instead of the `filename` parameter (which was otherwise
        # unused), so every upload collided on the same path.
        if subfolder:
            path_in_repo = f"data/{subfolder}/{filename}"
        else:
            path_in_repo = f"data/{filename}"

        api = HfApi(token=HF_TOKEN)
        api.upload_file(
            path_or_fileobj=file_content,
            path_in_repo=path_in_repo,
            repo_id=DATASET_ID,
            repo_type="dataset"
        )

        url = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{path_in_repo}"
        print(f"✅ Uploaded to HF Dataset: {url}")
        return url

    except Exception as e:
        print(f"❌ Failed to upload to HF Dataset: {e}")
        return None
363
+
364
def upload_image_to_hf_dataset(image, project_id, page_number, prompt, style=""):
    """PNG-encode a PIL image and push it into the project's folder in the HF dataset.

    Returns:
        The public URL on success, None on any failure (errors are printed,
        never raised to the caller).
    """
    try:
        # Encode entirely in memory; the local filesystem is not touched here.
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()

        # Filename: zero-padded page number + sanitized prompt slug + timestamp.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        slug = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
        slug = slug.replace(' ', '_')
        filename = f"page_{page_number:03d}_{slug}_{stamp}.png"

        return upload_to_hf_dataset(png_bytes, filename, f"projects/{project_id}")

    except Exception as e:
        print(f"❌ Failed to upload image to HF Dataset: {e}")
        return None
384
+
385
+ # PROMPT ENGINEERING
386
  def enhance_prompt_simple(scene_visual, style="childrens_book"):
387
  """Simple prompt enhancement - uses only the provided visual prompt with style"""
388
 
 
389
  style_templates = {
390
  "childrens_book": "children's book illustration, watercolor style, soft colors, whimsical, magical, storybook art, professional illustration",
391
  "realistic": "photorealistic, detailed, natural lighting, professional photography",
 
395
 
396
  style_prompt = style_templates.get(style, style_templates["childrens_book"])
397
 
 
398
  enhanced_prompt = f"{style_prompt}, {scene_visual}"
399
 
 
400
  negative_prompt = (
401
  "blurry, low quality, bad anatomy, deformed characters, "
402
  "wrong proportions, mismatched features"
 
404
 
405
  return enhanced_prompt, negative_prompt
406
 
407
+ # =============================================
408
+ # OPTIMIZED IMAGE GENERATION - MAINTAINS QUALITY
409
+ # =============================================
410
  def generate_image_simple(prompt, model_choice, style, scene_number, consistency_seed=None):
411
+ """Generate image with memory optimization but maintaining quality"""
412
 
 
413
  enhanced_prompt, negative_prompt = enhance_prompt_simple(prompt, style)
414
 
 
415
  if consistency_seed:
416
  scene_seed = consistency_seed + scene_number
417
  else:
 
420
  try:
421
  pipe = load_model(model_choice)
422
 
423
+ # Use memory efficient generation
424
+ with torch.inference_mode(): # More memory efficient than no_grad
425
+ image = pipe(
426
+ prompt=enhanced_prompt,
427
+ negative_prompt=negative_prompt,
428
+ num_inference_steps=35, # Keep quality
429
+ guidance_scale=7.5,
430
+ width=768, # Keep quality
431
+ height=1024, # Keep quality
432
+ generator=torch.Generator(device="cpu").manual_seed(scene_seed)
433
+ ).images[0]
434
+
435
+ # Clean up after generation
436
+ if torch.cuda.is_available():
437
+ torch.cuda.empty_cache()
438
 
439
  print(f"✅ Generated image for scene {scene_number}")
 
 
 
440
  return image
441
 
442
  except Exception as e:
 
451
  safe_prompt = "".join(c for c in prompt[:50] if c.isalnum() or c in (' ', '-', '_')).rstrip()
452
  filename = f"image_{safe_prompt}_{timestamp}.png"
453
 
 
454
  style_dir = os.path.join(PERSISTENT_IMAGE_DIR, style)
455
  os.makedirs(style_dir, exist_ok=True)
456
  filepath = os.path.join(style_dir, filename)
457
 
 
458
  image.save(filepath)
459
  print(f"💾 Image saved locally: {filepath}")
460
 
 
469
  try:
470
  if os.path.exists(filepath):
471
  os.remove(filepath)
 
472
  return True, f"✅ Deleted: {os.path.basename(filepath)}"
473
  else:
474
  return False, f"❌ File not found: {filepath}"
 
520
  print(f"Error refreshing local images: {e}")
521
  return []
522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  # JOB MANAGEMENT FUNCTIONS
524
  def create_job(story_request: StorybookRequest) -> str:
525
  job_id = str(uuid.uuid4())
 
536
  }
537
 
538
  print(f"📝 Created job {job_id} for story: {story_request.story_title}")
 
 
539
  return job_id
540
 
541
  def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
 
552
  if result:
553
  job_storage[job_id]["result"] = result
554
 
 
 
 
 
555
  if request_data.get("callback_url"):
556
  try:
557
  callback_url = request_data["callback_url"]
 
558
  callback_data = {
559
  "job_id": job_id,
560
  "status": status.value,
561
  "progress": progress,
562
  "message": message,
563
  "story_title": request_data["story_title"],
564
+ "timestamp": time.time()
 
 
 
565
  }
566
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  if status == JobStatus.COMPLETED and result:
568
  callback_data["result"] = {
569
+ "image_urls": result.get("image_urls", []),
570
+ "project_id": result.get("project_id", "")
 
 
 
 
571
  }
572
 
573
+ requests.post(callback_url, json=callback_data, timeout=5)
574
+ print(f"📢 Callback sent to {callback_url}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
  except Exception as e:
576
+ print(f"⚠️ Callback failed: {e}")
577
 
578
  return True
579
 
 
594
 
595
  return "Unknown"
596
 
597
+ # OPTIMIZED BACKGROUND TASK
598
  def generate_storybook_background(job_id: str):
599
+ """Background task with memory optimization"""
600
  try:
 
601
  if HF_TOKEN:
602
  ensure_dataset_exists()
603
 
604
  job_data = job_storage[job_id]
605
+ story_request = StorybookRequest(**job_data["request"])
 
606
 
 
607
  project_id = story_request.project_id or story_request.story_title.replace(' ', '_').lower()
608
 
609
  print(f"🎬 Starting storybook generation for job {job_id}")
 
 
 
 
610
 
611
+ update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting generation...")
612
 
613
  total_scenes = len(story_request.scenes)
614
  generated_pages = []
 
622
  job_id,
623
  JobStatus.PROCESSING,
624
  progress,
625
+ f"Generating page {i+1}/{total_scenes}"
626
  )
627
 
628
  try:
629
+ # Generate image
 
 
 
630
  image = generate_image_simple(
631
  scene.visual,
632
  story_request.model_choice,
 
635
  story_request.consistency_seed
636
  )
637
 
638
+ # Save locally
639
  local_filepath, local_filename = save_image_to_local(image, scene.visual, story_request.style)
 
640
 
641
  # Upload to HF Dataset
642
  hf_url = None
 
651
 
652
  if hf_url:
653
  image_urls.append(hf_url)
 
654
 
 
655
  page_data = {
656
  "page_number": i + 1,
657
  "image_url": hf_url or f"local://{local_filepath}",
 
658
  "text_content": scene.text,
659
+ "visual_description": scene.visual
 
 
660
  }
661
  generated_pages.append(page_data)
662
 
663
+ # Clean up after each page
664
+ if torch.cuda.is_available():
665
+ torch.cuda.empty_cache()
666
+ gc.collect()
667
 
668
  except Exception as e:
669
+ print(f" Page {i+1} failed: {e}")
670
+ update_job_status(job_id, JobStatus.FAILED, progress, str(e))
 
671
  return
672
 
 
673
  generation_time = time.time() - start_time
674
 
 
 
 
 
675
  result = {
676
  "story_title": story_request.story_title,
677
  "project_id": project_id,
678
  "total_pages": total_scenes,
 
679
  "generation_time": round(generation_time, 2),
680
  "hf_dataset_url": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
 
 
681
  "image_urls": image_urls,
682
+ "pages": generated_pages
 
 
 
 
683
  }
684
 
 
 
 
 
 
 
685
  update_job_status(
686
  job_id,
687
  JobStatus.COMPLETED,
688
  100,
689
+ f"✅ Completed! {len(image_urls)} images uploaded",
690
  result
691
  )
692
 
 
 
 
693
  except Exception as e:
694
+ error_msg = f"Generation failed: {str(e)}"
695
  print(f"❌ {error_msg}")
696
  update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
697
 
698
+ # FASTAPI ENDPOINTS
699
@app.post("/api/generate-storybook")
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
    """Validate the incoming request, create a job, and launch background generation.

    Returns a job descriptor (job_id, project_id, estimate) immediately;
    actual image generation happens in the background task.

    Raises:
        HTTPException 400 when story_title or scenes are missing,
        HTTPException 500 for any other failure.
    """
    try:
        print(f"📥 Received request for: {request.get('story_title', 'Unknown')}")

        # Fill in optional fields so downstream code can rely on them.
        if 'consistency_seed' not in request:
            request['consistency_seed'] = random.randint(1000, 9999)

        if 'project_id' not in request:
            request['project_id'] = request.get('story_title', 'unknown').replace(' ', '_').lower()

        story_request = StorybookRequest(**request)

        if not story_request.story_title or not story_request.scenes:
            raise HTTPException(status_code=400, detail="story_title and scenes required")

        job_id = create_job(story_request)
        background_tasks.add_task(generate_storybook_background, job_id)

        return {
            "status": "success",
            "job_id": job_id,
            "story_title": story_request.story_title,
            "project_id": request['project_id'],
            "total_scenes": len(story_request.scenes),
            "hf_dataset": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
            "estimated_time_seconds": len(story_request.scenes) * 35
        }

    except HTTPException:
        # BUG FIX: previously the 400 raised above was caught by the generic
        # handler below and re-raised as a 500; let it propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
 
 
730
 
731
@app.get("/api/job-status/{job_id}")
async def get_job_status(job_id: str):
    """Look up a generation job by id and report its current state."""
    entry = job_storage.get(job_id)
    if not entry:
        raise HTTPException(status_code=404, detail="Job not found")

    # "status" is a JobStatus enum; expose its string value to the client.
    return {
        "job_id": job_id,
        "status": entry["status"].value,
        "progress": entry["progress"],
        "message": entry["message"],
        "result": entry["result"],
    }
 
 
744
 
745
  @app.get("/api/health")
746
+ async def health():
 
747
  return {
748
  "status": "healthy",
749
  "service": "storybook-generator",
750
  "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
751
+ "active_jobs": len(job_storage)
 
 
 
752
  }
753
 
 
754
  @app.get("/api/project-images/{project_id}")
755
  async def get_project_images(project_id: str):
 
756
  try:
757
  if not HF_TOKEN:
758
  return {"error": "HF_TOKEN not set"}
 
761
  files = api.list_repo_files(repo_id=DATASET_ID, repo_type="dataset")
762
 
763
  project_files = [f for f in files if f.startswith(f"data/projects/{project_id}/")]
 
764
  urls = [f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{f}" for f in project_files]
765
 
766
+ return {"project_id": project_id, "total_images": len(urls), "image_urls": urls}
 
 
 
 
767
  except Exception as e:
768
  return {"error": str(e)}
769
 
770
+ # GRADIO INTERFACE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771
def create_gradio_interface():
    """Build the minimal Gradio test UI for single-image generation."""

    def generate_test(prompt, model_choice, style_choice):
        """Generate one image from the prompt, save locally, return (image, status)."""
        if not prompt.strip():
            return None, "❌ Please enter a prompt"

        image = generate_image_simple(prompt, model_choice, style_choice, 1)
        filepath, filename = save_image_to_local(image, prompt, style_choice)

        # BUG FIX: the status message previously interpolated a placeholder
        # literal instead of the saved filename.
        return image, f"✅ Generated! Local: {filename}"

    with gr.Blocks(title="Storybook Generator") as demo:
        gr.Markdown("# 🎨 Storybook Generator")

        with gr.Row():
            with gr.Column():
                model = gr.Dropdown(choices=list(MODEL_CHOICES.keys()), value="dreamshaper-8", label="Model")
                style = gr.Dropdown(choices=["childrens_book", "realistic", "fantasy", "anime"], value="anime", label="Style")
                prompt = gr.Textbox(label="Prompt", lines=3)
                btn = gr.Button("Generate", variant="primary")

            with gr.Column():
                output = gr.Image(label="Generated Image", height=500)
                status = gr.Textbox(label="Status")

        btn.click(fn=generate_test, inputs=[prompt, model, style], outputs=[output, status])

    return demo
798
 
 
799
  demo = create_gradio_interface()
800
 
 
801
@app.get("/")
async def root():
    """Top-level service description: available endpoints, dataset, UI path."""
    endpoint_map = {
        "generate": "POST /api/generate-storybook",
        "status": "GET /api/job-status/{job_id}",
        "health": "GET /api/health",
        "project_images": "GET /api/project-images/{project_id}",
    }
    return {
        "message": "Storybook Generator API",
        "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
        "endpoints": endpoint_map,
        "ui": "/ui",
    }
814
 
 
 
 
 
815
if __name__ == "__main__":
    import uvicorn

    # SPACE_ID is injected only inside a Hugging Face Space container.
    on_hf_space = bool(os.environ.get('SPACE_ID'))

    if not on_hf_space:
        print("🚀 Running locally")
        uvicorn.run(app, host="0.0.0.0", port=8000)
    else:
        print("🚀 Running on Hugging Face Spaces")
        print(f"📦 HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
        # Serve the API and the Gradio UI from one server on the Space port.
        gr.mount_gradio_app(app, demo, path="/ui")
        uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")