Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler | |
| from PIL import Image | |
| import io | |
| import requests | |
| import os | |
| from datetime import datetime | |
| import re | |
| import time | |
| import json | |
| from typing import List, Optional, Dict | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks, Request | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| import gc | |
| import psutil | |
| import threading | |
| import uuid | |
| import hashlib | |
| from enum import Enum | |
| import random | |
| import time | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util.retry import Retry | |
| from huggingface_hub import HfApi | |
| import sys | |
| import traceback | |
# =============================================
# INITIAL SETUP & DIAGNOSTICS
# =============================================
# Startup banner followed by the runtime versions that matter most when
# debugging the deployment: Python, PyTorch and whether CUDA is visible.
print(
    "\n".join(
        [
            "=" * 60,
            "π STARTING STORYBOOK GENERATOR API",
            "=" * 60,
            f"Python version: {sys.version}",
            f"PyTorch version: {torch.__version__}",
            f"CUDA available: {torch.cuda.is_available()}",
        ]
    )
)
# =============================================
# CREATE FASTAPI APP FIRST
# =============================================
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI(title="Storybook Generator API")

# CORS: accept requests from any origin, with any method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination — confirm credentialed cross-origin access is needed.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# =============================================
# DEFINE ALL API ROUTES FIRST (BEFORE GRADIO)
# =============================================
@app.get("/test")
async def test_endpoint():
    """Simple test endpoint that should always work: ``GET /test``.

    The decorator is what registers the route on ``app``; without it the
    function exists but no request can reach it.

    Returns:
        A static status payload plus the server's local ISO timestamp.
    """
    return {
        "status": "ok",
        "message": "Test endpoint is working",
        "timestamp": datetime.now().isoformat(),
    }
@app.get("/ping")
async def ping():
    """Simple ping endpoint that always works: ``GET /ping``.

    Registered on ``app`` via the decorator so the route actually exists.

    Returns:
        ``{"status": "alive", "timestamp": <local ISO time>, "message": ...}``.
    """
    return {
        "status": "alive",
        "timestamp": datetime.now().isoformat(),
        "message": "API is running",
    }
@app.get("/debug")
async def debug():
    """Debug endpoint showing system status: ``GET /debug``.

    Reports interpreter/torch versions, CUDA availability, whether HF_TOKEN is
    set, and every route registered on ``app``.

    Mounted sub-applications (such as the Gradio UI mounted on this app) and
    websocket routes carry no ``methods`` attribute; reading it directly would
    raise AttributeError and turn this endpoint into a 500. Such routes are
    listed with an empty method list instead. Methods are sorted so the
    output is stable.
    """
    return {
        "app_started": True,
        "python_version": sys.version,
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "routes": [
            {"path": route.path, "methods": sorted(getattr(route, "methods", None) or [])}
            for route in app.routes
        ],
        "hf_token_set": bool(os.environ.get("HF_TOKEN")),
        "timestamp": datetime.now().isoformat(),
    }
# =============================================
# HUGGING FACE DATASET CONFIGURATION
# =============================================
# Hub token from the environment; when it is unset, dataset creation and
# uploads are skipped by the functions that check it.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Target dataset repository, spelled "<user>/<dataset>".
HF_USERNAME = "yukee1992"
DATASET_NAME = "video-project-images"
DATASET_ID = "/".join((HF_USERNAME, DATASET_NAME))
print(f"π¦ HF Dataset: {DATASET_ID}")
print("π HF Token: " + ("β Set" if HF_TOKEN else "β Missing"))

# Local folder for generated images (one sub-folder per style is created on save).
PERSISTENT_IMAGE_DIR = "generated_test_images"
os.makedirs(PERSISTENT_IMAGE_DIR, exist_ok=True)
print(f"π Created local image directory: {PERSISTENT_IMAGE_DIR}")
# Job Status Enum
# Lifecycle states of a storybook generation job. Because the enum also
# subclasses ``str``, members compare equal to their string values; the API
# exposes them via ``.value`` (see get_job_status / update_job_status).
class JobStatus(str, Enum):
    PENDING = "pending"  # set by create_job when the job is queued
    PROCESSING = "processing"  # set while the background task generates pages
    COMPLETED = "completed"  # all pages generated; the job's "result" is filled in
    FAILED = "failed"  # generation stopped; the job's "message" holds the error text
# Simple Story scene model
# Request models for POST /api/generate-storybook (built from the JSON body).
class StoryScene(BaseModel):
    """One storybook page: an image prompt plus the page's text."""
    visual: str  # description fed to the image generator (wrapped by enhance_prompt_simple)
    text: str  # page text, copied into the job result as "text_content"


class CharacterDescription(BaseModel):
    """A named character with a free-text description."""
    name: str
    description: str


class StorybookRequest(BaseModel):
    """Full request for generating one storybook; one image is produced per scene."""
    story_title: str  # required; also the default source of project_id
    scenes: List[StoryScene]  # one entry per page; generate_storybook rejects an empty list
    # NOTE(review): characters are accepted but not read by the generation code in this file.
    characters: List[CharacterDescription] = []
    # Key of MODEL_CHOICES. NOTE(review): passed to generate_image_simple, which does not
    # use it to choose the pipeline — confirm intended behaviour.
    model_choice: str = "dreamshaper-8"
    # Style template key for enhance_prompt_simple (unknown keys fall back to the
    # children's-book template); also used as the local sub-folder name when saving.
    style: str = "childrens_book"
    callback_url: Optional[str] = None  # if set, update_job_status POSTs progress updates here
    # Base seed: page n is generated with seed + n. generate_storybook fills in a random
    # value when the key is absent from the request body.
    consistency_seed: Optional[int] = None
    # Sub-folder under data/projects/ in the HF dataset; when missing it is derived from
    # story_title (spaces -> underscores, lower-cased).
    project_id: Optional[str] = None
class JobStatusResponse(BaseModel):
    """Shape of a job-status payload.

    NOTE(review): not referenced as a response_model anywhere in this file;
    get_job_status returns a plain dict without the timestamps.
    """
    job_id: str
    status: JobStatus
    progress: int  # percentage, 0-100
    message: str
    result: Optional[dict] = None
    created_at: float  # time.time() epoch seconds
    updated_at: float  # time.time() epoch seconds


class MemoryClearanceRequest(BaseModel):
    """Options for POST /api/clear-memory; forwarded as-is to clear_memory()."""
    clear_models: bool = True  # unload every cached pipeline
    clear_jobs: bool = False  # drop COMPLETED/FAILED jobs from job_storage
    clear_local_images: bool = False  # delete images saved under PERSISTENT_IMAGE_DIR
    force_gc: bool = True  # run gc.collect() and empty the CUDA cache when available


class MemoryStatusResponse(BaseModel):
    """Shape of a memory report; mirrors the keys returned by get_memory_usage plus ``status``.

    NOTE(review): not referenced as a response_model anywhere in this file.
    """
    memory_used_mb: float  # process RSS in MB
    memory_percent: float
    models_loaded: int
    active_jobs: int
    local_images_count: int
    gpu_memory_allocated_mb: Optional[float] = None  # None when CUDA is unavailable
    gpu_memory_cached_mb: Optional[float] = None  # None when CUDA is unavailable
    status: str
# HIGH-QUALITY MODEL SELECTION - SAME AS WORKING VERSION
# Short model keys (used by the API and the UI dropdown) mapped to Hugging Face
# Hub repository ids. Insertion order is the order shown in the dropdown.
MODEL_CHOICES = dict(
    [
        ("dreamshaper-8", "lykon/dreamshaper-8"),
        ("realistic-vision", "SG161222/Realistic_Vision_V5.1"),
        ("counterfeit", "gsdf/Counterfeit-V2.5"),
        ("pastel-mix", "andite/pastel-mix"),
        ("meina-mix", "Meina/MeinaMix"),
        ("meina-pastel", "Meina/MeinaPastel"),
        ("abyss-orange", "warriorxza/AbyssOrangeMix"),
        ("openjourney", "prompthero/openjourney"),
        ("sd-1.5", "runwayml/stable-diffusion-v1-5"),
    ]
)
# GLOBAL STORAGE
# In-memory job registry: job_id -> job record dict (created by create_job).
job_storage = {}
# Loaded pipelines keyed by model name.
model_cache = {}
# Pipeline used for generation and the name it was registered under.
current_model_name, current_pipe = None, None
# Guards model loading/unloading.
model_lock = threading.Lock()
# Loader status: whether a load is in progress, and the last load error text.
model_loading, model_load_error = False, None
# MEMORY MANAGEMENT FUNCTIONS - FROM WORKING VERSION
def get_memory_usage():
    """Return a snapshot of process/GPU memory and in-app resource counts.

    Returns:
        dict with keys:
          - ``memory_used_mb``: process RSS in MB (2 decimals)
          - ``memory_percent``: psutil's process memory percentage
          - ``gpu_memory_allocated_mb`` / ``gpu_memory_cached_mb``: torch's
            allocated / reserved CUDA memory in MB, or None when CUDA is
            unavailable (a genuine 0.0 reading is reported as 0.0)
          - ``models_loaded``: number of pipelines in ``model_cache``
          - ``active_jobs``: number of jobs in ``job_storage`` (any status)
          - ``local_images_count``: images found by refresh_local_images()
    """
    bytes_per_mb = 1024 * 1024
    process = psutil.Process()
    memory_used_mb = process.memory_info().rss / bytes_per_mb
    memory_percent = process.memory_percent()

    gpu_memory_allocated_mb = None
    gpu_memory_cached_mb = None
    if torch.cuda.is_available():
        # Round here rather than behind a truthiness test: an idle GPU reports
        # 0 bytes, which must stay 0.0 instead of collapsing into None.
        gpu_memory_allocated_mb = round(torch.cuda.memory_allocated() / bytes_per_mb, 2)
        gpu_memory_cached_mb = round(torch.cuda.memory_reserved() / bytes_per_mb, 2)

    return {
        "memory_used_mb": round(memory_used_mb, 2),
        "memory_percent": round(memory_percent, 2),
        "gpu_memory_allocated_mb": gpu_memory_allocated_mb,
        "gpu_memory_cached_mb": gpu_memory_cached_mb,
        "models_loaded": len(model_cache),
        "active_jobs": len(job_storage),
        "local_images_count": len(refresh_local_images()),
    }
def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False, force_gc=True):
    """Free memory by unloading models and/or dropping finished jobs and local images.

    Args:
        clear_models: move every cached pipeline to CPU, empty ``model_cache``
            and reset ``current_pipe`` / ``current_model_name`` to None
            (done while holding ``model_lock``).
        clear_jobs: delete COMPLETED and FAILED jobs from ``job_storage``;
            pending and processing jobs are kept.
        clear_local_images: delete every image listed by get_local_storage_info().
        force_gc: run ``gc.collect()`` and, when CUDA is available, empty and
            synchronise the CUDA cache.

    Returns:
        dict with ``status`` ("success"), ``actions_performed`` (human-readable
        log lines) and ``memory_after_cleanup`` (a get_memory_usage() snapshot).
    """
    # Declared up front: both names are rebound below.
    global current_pipe, current_model_name
    results = []

    if clear_models:
        with model_lock:
            models_cleared = len(model_cache)
            for model_name, pipe in model_cache.items():
                try:
                    if hasattr(pipe, 'to'):
                        pipe.to('cpu')
                    results.append(f"Unloaded model: {model_name}")
                except Exception as e:
                    results.append(f"Error unloading {model_name}: {str(e)}")
            model_cache.clear()
            # NOTE(review): after this, generation reports "Model failed to load: None"
            # until load_model() runs again; only the startup thread calls it.
            current_pipe = None
            current_model_name = None
        results.append(f"Cleared {models_cleared} models from cache")

    if clear_jobs:
        # Work on a snapshot: background tasks add and update jobs concurrently,
        # and iterating the live dict can raise "dictionary changed size".
        finished_job_ids = [
            job_id
            for job_id, job_data in list(job_storage.items())
            if job_data["status"] in (JobStatus.COMPLETED, JobStatus.FAILED)
        ]
        for job_id in finished_job_ids:
            job_storage.pop(job_id, None)
            results.append(f"Cleared job: {job_id}")
        results.append(f"Cleared {len(finished_job_ids)} completed/failed jobs")

    if clear_local_images:
        try:
            storage_info = get_local_storage_info()
            deleted_count = 0
            # get_local_storage_info() returns {"error": ...} (no "images") on failure.
            for image_info in storage_info.get("images", []):
                success, _ = delete_local_image(image_info["path"])
                if success:
                    deleted_count += 1
            results.append(f"Deleted {deleted_count} local images")
        except Exception as e:
            results.append(f"Error clearing local images: {str(e)}")

    if force_gc:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            results.append("GPU cache cleared")
        results.append("Garbage collection forced")

    return {
        "status": "success",
        "actions_performed": results,
        "memory_after_cleanup": get_memory_usage(),
    }
# =============================================
# SIMPLIFIED MODEL LOADING - EXACTLY LIKE WORKING VERSION
# =============================================
def _load_pipeline(model_id):
    """Load ``model_id`` as a float32 CPU pipeline with the Euler-Ancestral scheduler."""
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
        cache_dir="./model_cache",
    )
    # Same scheduler for the primary and the fallback model, so both sample alike.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe.to("cpu")


def load_model(model_name="dreamshaper-8"):
    """Load ``model_name`` (or reuse it from ``model_cache``) and make it current.

    Unknown names resolve to ``lykon/dreamshaper-8``. If that load fails,
    Stable Diffusion v1.5 is loaded instead and registered under its own key
    ``"sd-1.5"``. The cache key and ``current_model_name`` therefore always
    agree, and a later request for the original model retries it.

    Updates the module globals ``current_pipe``, ``current_model_name``,
    ``model_loading`` and ``model_load_error``. ``model_loading`` stays True
    for the whole attempt, including the fallback, so callers report "still
    loading" rather than "failed" while the fallback downloads. After a
    successful fallback, ``model_load_error`` keeps the primary error text.

    NOTE(review): the ``runwayml/stable-diffusion-v1-5`` repo id may no longer
    resolve on the Hub — confirm it (a mirror exists at
    ``stable-diffusion-v1-5/stable-diffusion-v1-5``).

    Returns:
        The loaded pipeline.

    Raises:
        Exception: the fallback's error, when both loads fail.
    """
    global current_model_name, current_pipe, model_loading, model_load_error
    with model_lock:
        if model_name in model_cache:
            current_pipe = model_cache[model_name]
            current_model_name = model_name
            return current_pipe

        model_loading = True
        model_load_error = None
        print(f"π Loading model: {model_name}")
        try:
            model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
            try:
                pipe = _load_pipeline(model_id)
                loaded_name = model_name
                print(f"β Model loaded: {model_name}")
            except Exception as e:
                model_load_error = str(e)
                print(f"β Model loading failed for {model_name}: {e}")
                print("π Falling back to stable-diffusion-v1-5")
                try:
                    pipe = _load_pipeline("runwayml/stable-diffusion-v1-5")
                except Exception as fallback_error:
                    model_load_error = str(fallback_error)
                    print(f"β Fallback model failed: {fallback_error}")
                    raise
                loaded_name = "sd-1.5"
                print("β Fallback model loaded")

            model_cache[loaded_name] = pipe
            current_pipe = pipe
            current_model_name = loaded_name
            return pipe
        finally:
            # Cleared only once the attempt, fallback included, has finished.
            model_loading = False
# Try to load model in background thread to not block startup
def load_model_background():
    """Thread target: load the default model, printing (not raising) any failure."""
    try:
        load_model("dreamshaper-8")
    except Exception as e:
        print(f"β Background model loading failed: {e}")


# Kick off the load at import time. As a daemon thread it does not keep the
# process alive on its own.
model_thread = threading.Thread(target=load_model_background, daemon=True)
model_thread.start()
print("β³ Model loading started in background...")
# =============================================
# HF DATASET FUNCTIONS
# =============================================
def ensure_dataset_exists():
    """Make sure the Hub dataset ``DATASET_ID`` exists, creating it if necessary.

    Any failure of the existence check is treated as "missing" and triggers a
    public ``create_repo`` call with ``exist_ok=True``.

    Returns:
        True when the dataset exists or was created; False when HF_TOKEN is
        unset or a Hub call fails (the error is printed).
    """
    if not HF_TOKEN:
        print("β οΈ HF_TOKEN not set, cannot create/verify dataset")
        return False

    try:
        hub = HfApi(token=HF_TOKEN)
        try:
            hub.dataset_info(DATASET_ID)
            print(f"β Dataset {DATASET_ID} exists")
        except Exception:
            print(f"π¦ Creating dataset: {DATASET_ID}")
            hub.create_repo(repo_id=DATASET_ID, repo_type="dataset", private=False, exist_ok=True)
            print(f"β Created dataset: {DATASET_ID}")
        return True
    except Exception as e:
        print(f"β Failed to ensure dataset: {e}")
        return False
def upload_to_hf_dataset(file_content, filename, subfolder=""):
    """Upload ``file_content`` to the dataset at ``data/[<subfolder>/]<filename>``.

    Args:
        file_content: anything ``HfApi.upload_file`` accepts as ``path_or_fileobj``
            (this file passes raw PNG bytes).
        filename: name of the file inside the repo folder.
        subfolder: optional folder below ``data/``.

    Returns:
        The ``resolve/main`` download URL, or None when HF_TOKEN is unset or
        the upload fails (the error is printed).
    """
    if not HF_TOKEN:
        print("β οΈ HF_TOKEN not set, skipping upload")
        return None

    try:
        path_in_repo = f"data/{subfolder}/{filename}" if subfolder else f"data/{filename}"
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=file_content,
            path_in_repo=path_in_repo,
            repo_id=DATASET_ID,
            repo_type="dataset",
        )
        url = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{path_in_repo}"
        print(f"β Uploaded to HF Dataset: {url}")
        return url
    except Exception as e:
        print(f"β Failed to upload to HF Dataset: {e}")
        return None
def upload_image_to_hf_dataset(image, project_id, page_number, prompt, style=""):
    """Encode ``image`` as PNG and upload it under ``data/projects/<project_id>/``.

    The file is named ``page_<NNN>_<prompt slug>_<YYYYmmdd_HHMMSS>.png``. The
    slug keeps alphanumerics, spaces, '-' and '_' from the first 30 prompt
    characters, with spaces turned into underscores. ``style`` is accepted but
    unused.

    Returns:
        The uploaded file's URL, or None on any failure (the error is printed).
    """
    try:
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()

        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        kept_chars = [ch for ch in prompt[:30] if ch.isalnum() or ch in (' ', '-', '_')]
        slug = "".join(kept_chars).rstrip().replace(' ', '_')

        return upload_to_hf_dataset(
            png_bytes,
            f"page_{page_number:03d}_{slug}_{stamp}.png",
            f"projects/{project_id}",
        )
    except Exception as e:
        print(f"β Failed to upload image to HF Dataset: {e}")
        return None
# PROMPT ENGINEERING - FROM WORKING VERSION
def enhance_prompt_simple(scene_visual, style="childrens_book"):
    """Prefix ``scene_visual`` with a style template and pair it with a fixed negative prompt.

    Args:
        scene_visual: the scene description supplied by the caller.
        style: one of "childrens_book", "realistic", "fantasy", "anime";
            anything else uses the "childrens_book" template.

    Returns:
        ``(positive_prompt, negative_prompt)`` where the positive prompt is
        ``"<style template>, <scene_visual>"``.
    """
    templates = {
        "childrens_book": "children's book illustration, watercolor style, soft colors, whimsical, magical, storybook art, professional illustration",
        "realistic": "photorealistic, detailed, natural lighting, professional photography",
        "fantasy": "fantasy art, magical, ethereal, digital painting, concept art",
        "anime": "anime style, Japanese animation, vibrant colors, detailed artwork",
    }
    prefix = templates[style] if style in templates else templates["childrens_book"]
    negative = "blurry, low quality, bad anatomy, deformed characters, wrong proportions, mismatched features"
    return f"{prefix}, {scene_visual}", negative
# =============================================
# IMAGE GENERATION - EXACTLY LIKE WORKING VERSION
# =============================================
def generate_image_simple(prompt, model_choice, style, scene_number, consistency_seed=None):
    """Render one 768x1024 image for ``prompt`` with the currently loaded pipeline.

    Args:
        prompt: scene description; combined with the ``style`` template by
            enhance_prompt_simple().
        model_choice: NOTE(review): accepted but not used — the image always
            comes from ``current_pipe``. Honouring it would mean calling
            load_model(model_choice), and every cached pipeline stays in
            memory; confirm whether that is wanted.
        style: style template key passed to enhance_prompt_simple().
        scene_number: added to ``consistency_seed`` so pages differ yet stay
            reproducible; also used in the log line.
        consistency_seed: base seed. When None, a random seed in [1000, 9999]
            is drawn. 0 is a valid seed and is honoured.

    Returns:
        The first image produced by the pipeline call.

    Raises:
        Exception: when no model is available (still loading, or the load
        failed), or whatever the pipeline call raises.
    """
    pipe = current_pipe  # snapshot: clear_memory() may reset the global meanwhile
    if pipe is None:
        if model_loading:
            raise Exception("Model is still loading. Please wait a few seconds and try again.")
        raise Exception(f"Model failed to load: {model_load_error}")

    enhanced_prompt, negative_prompt = enhance_prompt_simple(prompt, style)

    # Compare against None: a truthiness test would treat seed 0 as "no seed".
    if consistency_seed is not None:
        scene_seed = consistency_seed + scene_number
    else:
        scene_seed = random.randint(1000, 9999)

    try:
        # Background jobs and the Gradio handler call this from worker threads
        # and share one pipeline, whose scheduler is mutated on every call.
        # Serialise generations; holding model_lock also keeps clear_memory()
        # and load_model() from swapping models mid-run.
        with model_lock:
            image = pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=35,
                guidance_scale=7.5,
                width=768,
                height=1024,
                generator=torch.Generator(device="cpu").manual_seed(scene_seed),
            ).images[0]
        print(f"β Generated image for scene {scene_number}")
        return image
    except Exception as e:
        print(f"β Generation failed: {str(e)}")
        raise
# LOCAL FILE MANAGEMENT FUNCTIONS - FROM WORKING VERSION
def save_image_to_local(image, prompt, style="test"):
    """Save ``image`` under ``PERSISTENT_IMAGE_DIR/<style>/``.

    The filename is ``image_<prompt slug>_<YYYYmmdd_HHMMSS>.png``. The slug
    keeps alphanumerics, spaces, '-' and '_' from the first 50 prompt
    characters.

    ``style`` can come straight from an API request (StorybookRequest.style),
    so only alphanumerics, '-' and '_' are kept before it is used as a
    directory name. Without this, a value like "../../x" or an absolute path
    (which os.path.join lets replace the base) would write outside
    PERSISTENT_IMAGE_DIR. An empty result falls back to "test". The built-in
    style names pass through unchanged.

    Returns:
        ``(filepath, filename)`` on success, ``(None, None)`` on any error
        (the error is printed, not raised).
    """
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_prompt = "".join(c for c in prompt[:50] if c.isalnum() or c in (' ', '-', '_')).rstrip()
        filename = f"image_{safe_prompt}_{timestamp}.png"
        safe_style = "".join(c for c in (style or "") if c.isalnum() or c in ('-', '_')) or "test"
        style_dir = os.path.join(PERSISTENT_IMAGE_DIR, safe_style)
        os.makedirs(style_dir, exist_ok=True)
        filepath = os.path.join(style_dir, filename)
        image.save(filepath)
        print(f"πΎ Image saved locally: {filepath}")
        return filepath, filename
    except Exception as e:
        print(f"β Failed to save locally: {e}")
        return None, None
def delete_local_image(filepath):
    """Remove ``filepath`` from disk.

    Returns:
        ``(True, message)`` after a successful delete, ``(False, message)``
        when the file does not exist or removal raised. The message names the
        file or carries the error text.
    """
    try:
        if not os.path.exists(filepath):
            return False, f"β File not found: {filepath}"
        os.remove(filepath)
        return True, f"β Deleted: {os.path.basename(filepath)}"
    except Exception as e:
        return False, f"β Error deleting: {str(e)}"
def get_local_storage_info():
    """Summarise the .png/.jpg/.jpeg files found anywhere under PERSISTENT_IMAGE_DIR.

    Returns:
        ``{"total_files": int, "total_size_mb": float, "images": [...]}``.
        Each image entry has ``path``, ``filename``, ``size_kb`` and
        ``created`` (os.path.getctime), sorted newest first. Returns
        ``{"error": message}`` if anything raises.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    try:
        entries = []
        total_bytes = 0
        for folder, _subdirs, names in os.walk(PERSISTENT_IMAGE_DIR):
            for name in names:
                if not name.endswith(image_suffixes):
                    continue
                path = os.path.join(folder, name)
                if not os.path.exists(path):
                    continue
                size_bytes = os.path.getsize(path)
                total_bytes += size_bytes
                entries.append({
                    'path': path,
                    'filename': name,
                    'size_kb': round(size_bytes / 1024, 1),
                    'created': os.path.getctime(path),
                })
        entries.sort(key=lambda entry: entry['created'], reverse=True)
        return {
            "total_files": len(entries),
            "total_size_mb": round(total_bytes / (1024 * 1024), 2),
            "images": entries,
        }
    except Exception as e:
        return {"error": str(e)}
def refresh_local_images():
    """Return the paths of every .png/.jpg/.jpeg file under PERSISTENT_IMAGE_DIR.

    Sub-folders are included, in os.walk order. On any error the problem is
    printed and an empty list is returned.
    """
    try:
        return [
            os.path.join(folder, name)
            for folder, _subdirs, names in os.walk(PERSISTENT_IMAGE_DIR)
            for name in names
            if name.endswith(('.png', '.jpg', '.jpeg'))
            and os.path.exists(os.path.join(folder, name))
        ]
    except Exception as e:
        print(f"Error refreshing local images: {e}")
        return []
# JOB MANAGEMENT FUNCTIONS
def create_job(story_request: StorybookRequest) -> str:
    """Register a new PENDING job for ``story_request`` and return its UUID4 id.

    The request is stored as a plain dict so the background task can rebuild
    the model from it. ``created_at`` and ``updated_at`` come from a single
    clock reading, so a fresh job reports identical timestamps.
    """
    job_id = str(uuid.uuid4())
    now = time.time()
    # Pydantic v2 renamed .dict() to .model_dump() (.dict() still works but warns).
    dump = getattr(story_request, "model_dump", None) or story_request.dict
    job_storage[job_id] = {
        "status": JobStatus.PENDING,
        "progress": 0,
        "message": "Job created and queued",
        "request": dump(),
        "result": None,
        "created_at": now,
        "updated_at": now,
        "pages": [],
    }
    print(f"π Created job {job_id} for story: {story_request.story_title}")
    return job_id
def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
    """Update the stored state of ``job_id`` and notify its callback URL, if any.

    Args:
        job_id: id returned by create_job.
        status: new JobStatus.
        progress: percentage 0-100.
        message: human-readable status text.
        result: optional result dict. It is stored whenever it is not None,
            including when it is empty.

    Returns:
        False if ``job_id`` is unknown, otherwise True.

    The callback is a POST with a 5 s timeout, sent synchronously on every
    update. Network errors and non-2xx responses are printed, never raised.
    For COMPLETED jobs with a result, ``image_urls`` and ``project_id`` are
    included in the callback payload.
    """
    job_data = job_storage.get(job_id)
    if job_data is None:
        return False

    job_data.update({
        "status": status,
        "progress": progress,
        "message": message,
        "updated_at": time.time(),
    })
    if result is not None:
        job_data["result"] = result

    request_data = job_data["request"]
    callback_url = request_data.get("callback_url")
    if callback_url:
        try:
            callback_data = {
                "job_id": job_id,
                "status": status.value,
                "progress": progress,
                "message": message,
                "story_title": request_data["story_title"],
                "timestamp": time.time(),
            }
            if status == JobStatus.COMPLETED and result:
                callback_data["result"] = {
                    "image_urls": result.get("image_urls", []),
                    "project_id": result.get("project_id", ""),
                }
            response = requests.post(callback_url, json=callback_data, timeout=5)
            # Treat 4xx/5xx replies as failures so the log reflects reality.
            response.raise_for_status()
            print(f"π’ Callback sent to {callback_url}")
        except Exception as e:
            print(f"β οΈ Callback failed: {e}")
    return True
def calculate_remaining_time(job_id, progress):
    """Estimate time left for ``job_id`` as ``"<m>m <s>s"``.

    Extrapolates linearly from the time elapsed since the job's
    ``created_at``. Returns "Calculating..." at 0% progress, and "Unknown"
    for unknown jobs or negative progress.
    """
    if progress == 0:
        return "Calculating..."

    job = job_storage.get(job_id)
    if not job:
        return "Unknown"

    elapsed = time.time() - job["created_at"]
    if progress > 0:
        projected_total = elapsed / progress * 100
        minutes, seconds = divmod(projected_total - elapsed, 60)
        return f"{int(minutes)}m {int(seconds)}s"
    return "Unknown"
# BACKGROUND TASK
def generate_storybook_background(job_id: str):
    """Generate every page of job ``job_id``; intended to run as a background task.

    For each scene:
      1. generate an image;
      2. save it locally with save_image_to_local();
      3. when HF_TOKEN is set, upload it to the HF dataset under the job's
         project_id.

    Progress goes through update_job_status, which also fires the optional
    callback: 5% for setup, then a share of 90% per *finished* page, then
    100% on completion. The first page that raises marks the whole job
    FAILED. Any other unexpected error marks it FAILED with progress 0.
    """
    try:
        if HF_TOKEN:
            ensure_dataset_exists()

        job_data = job_storage[job_id]
        story_request = StorybookRequest(**job_data["request"])
        project_id = story_request.project_id or story_request.story_title.replace(' ', '_').lower()

        print(f"π¬ Starting storybook generation for job {job_id}")
        update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting generation...")

        total_scenes = len(story_request.scenes)
        generated_pages = []
        image_urls = []
        start_time = time.time()

        for i, scene in enumerate(story_request.scenes):
            page_number = i + 1
            # Count only pages already finished; page i has not been generated yet.
            progress = 5 + int((i / total_scenes) * 90)
            update_job_status(
                job_id,
                JobStatus.PROCESSING,
                progress,
                f"Generating page {page_number}/{total_scenes}",
            )
            try:
                image = generate_image_simple(
                    scene.visual,
                    story_request.model_choice,
                    story_request.style,
                    page_number,
                    story_request.consistency_seed,
                )

                local_filepath, _ = save_image_to_local(image, scene.visual, story_request.style)

                hf_url = None
                if HF_TOKEN:
                    hf_url = upload_image_to_hf_dataset(
                        image,
                        project_id,
                        page_number,
                        scene.visual,
                        story_request.style,
                    )
                    if hf_url:
                        image_urls.append(hf_url)

                # Prefer the dataset URL; never emit "local://None" when the local save failed.
                if hf_url:
                    image_url = hf_url
                elif local_filepath:
                    image_url = f"local://{local_filepath}"
                else:
                    image_url = None

                generated_pages.append({
                    "page_number": page_number,
                    "image_url": image_url,
                    "text_content": scene.text,
                    "visual_description": scene.visual,
                })

                # Release per-page memory before the next page.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
            except Exception as e:
                print(f"β Page {i+1} failed: {e}")
                update_job_status(job_id, JobStatus.FAILED, progress, str(e))
                return

        generation_time = time.time() - start_time
        result = {
            "story_title": story_request.story_title,
            "project_id": project_id,
            "total_pages": total_scenes,
            "generation_time": round(generation_time, 2),
            "hf_dataset_url": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
            "image_urls": image_urls,
            "pages": generated_pages,
        }
        update_job_status(
            job_id,
            JobStatus.COMPLETED,
            100,
            f"β Completed! {len(image_urls)} images uploaded",
            result,
        )
    except Exception as e:
        error_msg = f"Generation failed: {str(e)}"
        print(f"β {error_msg}")
        traceback.print_exc()
        update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
# =============================================
# ADD ALL API ENDPOINTS HERE (BEFORE GRADIO)
# =============================================
@app.get("/")
async def root():
    """Root endpoint showing API status: ``GET /``.

    Reports the model loader state (loaded name, in-progress flag, last error),
    the target HF dataset (or "Disabled" without a token), the endpoint map
    and the UI path. The decorator registers the route on ``app``.
    """
    return {
        "name": "Storybook Generator API",
        "version": "1.0.0",
        "status": "running",
        "model_status": {
            "loaded": current_model_name is not None,
            "model_name": current_model_name,
            "loading": model_loading,
            "error": model_load_error,
        },
        "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
        "endpoints": {
            "test": "GET /test",
            "ping": "GET /ping",
            "debug": "GET /debug",
            "health": "GET /api/health",
            "generate": "POST /api/generate-storybook",
            "status": "GET /api/job-status/{job_id}",
            "project_images": "GET /api/project-images/{project_id}",
            "memory": "GET /api/memory-status",
            "clear_memory": "POST /api/clear-memory",
            "local_images": "GET /api/local-images",
        },
        "ui": "/ui",
    }
@app.get("/api/health")
async def health():
    """Health check endpoint: ``GET /api/health``.

    Always reports "healthy" and adds the model loader state, HF dataset
    target, job count (any status) and a local ISO timestamp. The decorator
    registers the route on ``app``.
    """
    return {
        "status": "healthy",
        "service": "storybook-generator",
        "model_loaded": current_model_name is not None,
        "model_name": current_model_name,
        "model_loading": model_loading,
        "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
        "active_jobs": len(job_storage),
        "timestamp": datetime.now().isoformat(),
    }
@app.post("/api/generate-storybook")
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
    """Queue a storybook generation job: ``POST /api/generate-storybook``.

    ``request`` is the raw JSON body. Before it is validated into
    StorybookRequest, missing keys get defaults: ``consistency_seed`` becomes
    a random int in [1000, 9999], and ``project_id`` is derived from
    ``story_title`` (spaces become underscores, lower-cased).

    Responses:
        503 while the model is still loading, or if it failed to load.
        422 when the body does not match StorybookRequest.
        400 when story_title or scenes is empty.
        500 for any other unexpected error.
        Otherwise, a JSON summary with the new ``job_id``. Generation then
        runs in generate_storybook_background.
    """
    from pydantic import ValidationError

    try:
        print(f"π₯ Received request for: {request.get('story_title', 'Unknown')}")

        # Check if model is loaded
        if current_pipe is None:
            if model_loading:
                return JSONResponse(
                    status_code=503,
                    content={
                        "status": "loading",
                        "message": "Model is still loading. Please wait a few seconds and try again.",
                        "estimated_time": "10-20 seconds",
                    },
                )
            return JSONResponse(
                status_code=503,
                content={
                    "status": "error",
                    "message": f"Model failed to load: {model_load_error}",
                    "error": model_load_error,
                },
            )

        if 'consistency_seed' not in request:
            request['consistency_seed'] = random.randint(1000, 9999)
        if 'project_id' not in request:
            # `or` covers an explicit null title; str() lets a non-string title reach
            # validation (422) instead of crashing here.
            title = str(request.get('story_title') or 'unknown')
            request['project_id'] = title.replace(' ', '_').lower()

        story_request = StorybookRequest(**request)
        if not story_request.story_title or not story_request.scenes:
            raise HTTPException(status_code=400, detail="story_title and scenes required")

        job_id = create_job(story_request)
        background_tasks.add_task(generate_storybook_background, job_id)

        return {
            "status": "success",
            "job_id": job_id,
            "story_title": story_request.story_title,
            "project_id": request['project_id'],
            "total_scenes": len(story_request.scenes),
            "hf_dataset": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
            "estimated_time_seconds": len(story_request.scenes) * 35,
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must not be rewrapped as 500.
        raise
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=str(e))
    except Exception as e:
        print(f"β Error in generate_storybook: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/job-status/{job_id}")
async def get_job_status(job_id: str):
    """Get job status by ID: ``GET /api/job-status/{job_id}``.

    Returns:
        The job's status string, progress (0-100), message and result
        (None until the job completes).

    Raises:
        HTTPException: 404 when the job id is unknown.
    """
    job_data = job_storage.get(job_id)
    if not job_data:
        raise HTTPException(status_code=404, detail="Job not found")
    return {
        "job_id": job_id,
        "status": job_data["status"].value,
        "progress": job_data["progress"],
        "message": job_data["message"],
        "result": job_data["result"],
    }
@app.get("/api/project-images/{project_id}")
async def get_project_images(project_id: str):
    """List dataset URLs for every file under ``data/projects/<project_id>/``.

    ``GET /api/project-images/{project_id}``. Problems (missing HF_TOKEN,
    Hub errors) are returned as ``{"error": ...}`` in a normal 200 response.

    ``HfApi.list_repo_files`` is a blocking network call. It runs in
    FastAPI's threadpool so it does not stall the event loop for every other
    request.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        if not HF_TOKEN:
            return {"error": "HF_TOKEN not set"}
        api = HfApi(token=HF_TOKEN)
        files = await run_in_threadpool(api.list_repo_files, repo_id=DATASET_ID, repo_type="dataset")
        prefix = f"data/projects/{project_id}/"
        urls = [
            f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{f}"
            for f in files
            if f.startswith(prefix)
        ]
        return {"project_id": project_id, "total_images": len(urls), "image_urls": urls}
    except Exception as e:
        return {"error": str(e)}
@app.get("/api/memory-status")
async def memory_status():
    """Get memory usage status: ``GET /api/memory-status``.

    Returns the get_memory_usage() snapshot unchanged.
    """
    return get_memory_usage()
@app.post("/api/clear-memory")
async def clear_memory_api(request: Optional[MemoryClearanceRequest] = None):
    """Clear memory manually: ``POST /api/clear-memory``.

    The JSON body is optional. When it is omitted, the MemoryClearanceRequest
    defaults apply: unload models, keep jobs and images, force GC.

    clear_memory() can block for a long time: it waits on ``model_lock``,
    which is held while a model loads, and it forces garbage collection. It
    therefore runs in FastAPI's threadpool rather than on the event loop.

    Returns:
        clear_memory()'s result dict.
    """
    from fastapi.concurrency import run_in_threadpool

    options = request if request is not None else MemoryClearanceRequest()
    return await run_in_threadpool(
        clear_memory,
        clear_models=options.clear_models,
        clear_jobs=options.clear_jobs,
        clear_local_images=options.clear_local_images,
        force_gc=options.force_gc,
    )
@app.get("/api/local-images")
async def get_local_images():
    """Get locally saved images: ``GET /api/local-images``.

    Returns the get_local_storage_info() summary: file count, total size and
    per-image entries, newest first.
    """
    return get_local_storage_info()
# =============================================
# GRADIO INTERFACE (CREATED AFTER API ROUTES)
# =============================================
def create_gradio_interface():
    """Build (but do not launch) a single-prompt test UI.

    Inputs: model dropdown, style dropdown and prompt box. Output: the
    generated image plus a status line. Each generated image is also saved
    locally via save_image_to_local().
    """

    # Click handler. Keep the name: Gradio may derive the event's API name from it.
    def generate_test(prompt, model_choice, style_choice):
        if not prompt.strip():
            return None, "β Please enter a prompt"
        try:
            if current_pipe is None:
                not_ready = (
                    "β³ Model is still loading. Please wait a few seconds..."
                    if model_loading
                    else f"β Model failed to load: {model_load_error}"
                )
                return None, not_ready
            image = generate_image_simple(prompt, model_choice, style_choice, 1)
            _, filename = save_image_to_local(image, prompt, style_choice)
            return image, f"β Generated! Local: {filename}"
        except Exception as e:
            return None, f"β Error: {str(e)}"

    with gr.Blocks(title="Storybook Generator") as ui:
        gr.Markdown("# π¨ Storybook Generator")
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    choices=list(MODEL_CHOICES.keys()), value="dreamshaper-8", label="Model"
                )
                style_dropdown = gr.Dropdown(
                    choices=["childrens_book", "realistic", "fantasy", "anime"], value="anime", label="Style"
                )
                prompt_box = gr.Textbox(label="Prompt", lines=3)
                generate_button = gr.Button("Generate", variant="primary")
            with gr.Column():
                image_output = gr.Image(label="Generated Image", height=500)
                status_box = gr.Textbox(label="Status")
        generate_button.click(
            fn=generate_test,
            inputs=[prompt_box, model_dropdown, style_dropdown],
            outputs=[image_output, status_box],
        )
    return ui
# Create Gradio interface
demo = create_gradio_interface()
# =============================================
# MOUNT GRADIO (AFTER ALL API ROUTES)
# =============================================
# Serve the Gradio Blocks under /ui on the same FastAPI application.
gr.mount_gradio_app(app=app, blocks=demo, path="/ui")
# =============================================
# MAIN - RUN THE APP
# =============================================
if __name__ == "__main__":
    import uvicorn

    print("π Running on Hugging Face Spaces")
    print(f"π¦ HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
    print("π‘ API endpoints:")
    for endpoint in (
        "GET /test",
        "GET /ping",
        "GET /debug",
        "GET /",
        "GET /api/health",
        "POST /api/generate-storybook",
        "GET /api/job-status/{job_id}",
        "GET /api/project-images/{project_id}",
        "GET /api/memory-status",
        "POST /api/clear-memory",
        "GET /api/local-images",
    ):
        print(" - " + endpoint)
    print("π¨ UI: /ui")
    # Bind on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")