"""Storybook Generator.

FastAPI service plus a Gradio test UI that renders children's-storybook
scenes and character art with Stable Diffusion 1.5, optionally strips
character backgrounds with rembg, and uploads results to an external
OCI bucket API. Jobs run in FastAPI background tasks and are tracked in
an in-memory job store.
"""

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
import io
import requests
import os
from datetime import datetime
import time
import json
from typing import List, Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
import threading
import uuid
import random
from enum import Enum
import numpy as np

# Try to import optional dependencies: rembg enables transparent
# character cut-outs; everything else degrades gracefully without it.
try:
    from rembg import remove
    REMBG_AVAILABLE = True
except ImportError:
    REMBG_AVAILABLE = False
    print("⚠️ rembg not available, character transparency disabled")

# External OCI API URL
OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"

# Create local directories (local fallback storage for generated assets)
PERSISTENT_IMAGE_DIR = "generated_test_images"
CHARACTERS_DIR = "characters"
os.makedirs(PERSISTENT_IMAGE_DIR, exist_ok=True)
os.makedirs(CHARACTERS_DIR, exist_ok=True)
print(f"📁 Created local directories")

# Initialize FastAPI app
app = FastAPI(title="Storybook Generator API")

from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class JobStatus(str, Enum):
    """Lifecycle states of a background generation job."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class StoryScene(BaseModel):
    """One page of the storybook: a visual prompt plus the page text."""
    visual: str
    text: str
    characters_present: List[str] = []


class CharacterDescription(BaseModel):
    """A recurring character; visual_prompt (if set) overrides description."""
    name: str
    description: str
    visual_prompt: str = ""
    key_features: List[str] = []


class StorybookRequest(BaseModel):
    """Payload for POST /api/generate-storybook."""
    story_title: str
    scenes: List[StoryScene]
    characters: List[CharacterDescription] = []
    model_choice: str = "sd-1.5"
    style: str = "childrens_book"
    callback_url: Optional[str] = None
    # One seed reused for every image to keep characters visually consistent.
    consistency_seed: Optional[int] = None


class JobStatusResponse(BaseModel):
    """Shape of GET /api/job-status/{job_id} responses."""
    job_id: str
    status: JobStatus
    progress: int
    message: str
    result: Optional[dict] = None
    created_at: float
    updated_at: float


# Model configuration - Using smaller model for better compatibility.
# NOTE(review): "revision" is declared but never passed to from_pretrained
# below — confirm whether the fp16 revision should actually be requested.
MODEL_CONFIG = {
    "sd-1.5": {
        "model_id": "runwayml/stable-diffusion-v1-5",
        "revision": "fp16",
        "torch_dtype": torch.float16
    }
}

job_storage = {}    # job_id -> job dict; in-memory only, lost on restart
model_cache = {}    # model name -> loaded pipeline
current_pipe = None
model_lock = threading.Lock()  # serializes model loading across threads


def load_model(model_name="sd-1.5"):
    """Load (and cache) a Stable Diffusion pipeline.

    Thread-safe via model_lock. On failure, retries once with a float32
    fallback load of SD 1.5; re-raises the original error if that also fails.

    Returns:
        The loaded StableDiffusionPipeline.
    """
    global model_cache, current_pipe
    with model_lock:
        if model_name in model_cache:
            current_pipe = model_cache[model_name]
            return current_pipe

        print(f"🔄 Loading model: {model_name}")
        try:
            config = MODEL_CONFIG[model_name]
            # Use simpler loading; safety checker disabled for children's-book
            # prompts to save memory.
            pipe = StableDiffusionPipeline.from_pretrained(
                config["model_id"],
                torch_dtype=config["torch_dtype"],
                safety_checker=None,
                requires_safety_checker=False
            )

            # Configure scheduler
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

            # Move to appropriate device
            if torch.cuda.is_available():
                pipe = pipe.to("cuda")
                print("✅ Using CUDA")
            else:
                pipe = pipe.to("cpu")
                print("✅ Using CPU")

            # Enable memory efficient attention
            pipe.enable_attention_slicing()

            model_cache[model_name] = pipe
            current_pipe = pipe
            print(f"✅ Model loaded successfully: {model_name}")
            return pipe

        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            # Try fallback model (float32 — works on CPUs without fp16 support)
            try:
                print("🔄 Trying fallback model...")
                pipe = StableDiffusionPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    torch_dtype=torch.float32
                )
                pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
                pipe.enable_attention_slicing()
                model_cache[model_name] = pipe
                current_pipe = pipe
                print("✅ Fallback model loaded successfully")
                return pipe
            except Exception as fallback_error:
                print(f"❌ Fallback model also failed: {fallback_error}")
                raise e


def generate_simple_image(prompt, negative_prompt="", seed=None, width=512, height=512):
    """Generate one image; on any failure return a solid red placeholder.

    Args:
        prompt: positive text prompt.
        negative_prompt: things to avoid.
        seed: optional int for reproducible output (0 is a valid seed).
        width/height: output dimensions in pixels.
    """
    try:
        pipe = load_model("sd-1.5")
        if pipe is None:
            raise Exception("Model not available")

        generator = None
        # BUGFIX: was `if seed:`, which silently ignored seed == 0 and broke
        # reproducibility for that legal value.
        if seed is not None:
            generator = torch.Generator(device=pipe.device).manual_seed(seed)

        # Generate image; autocast is unsupported on mps, so fall back to cpu
        with torch.autocast(pipe.device.type if pipe.device.type != 'mps' else 'cpu'):
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=20,
                guidance_scale=7.5,
                width=width,
                height=height,
                generator=generator
            )
        return result.images[0]

    except Exception as e:
        print(f"❌ Image generation failed: {e}")
        # Create a simple error image so the pipeline can continue
        error_image = Image.new('RGB', (width, height), color='red')
        return error_image


def generate_character_image(character_desc, seed=None):
    """Generate a character portrait (RGBA, transparent if rembg works).

    Falls back to a translucent red 512x512 RGBA image on failure.
    """
    try:
        character_prompt = f"{character_desc.visual_prompt or character_desc.description}, character design, clean lines, isolated on plain background, cartoon style, children's book illustration"
        negative_prompt = "blurry, low quality, complex background, multiple characters, dark, scary"

        image = generate_simple_image(
            character_prompt, negative_prompt, seed, width=512, height=512
        )

        # If rembg is available, remove background for a transparent cut-out
        if REMBG_AVAILABLE:
            try:
                image = remove(image)
            except Exception as bg_error:
                print(f"⚠️ Background removal failed: {bg_error}")
                # Convert to RGBA anyway so downstream code sees one mode
                image = image.convert('RGBA')
        else:
            image = image.convert('RGBA')

        return image
    except Exception as e:
        print(f"❌ Character generation failed: {e}")
        error_image = Image.new('RGBA', (512, 512), (255, 0, 0, 128))
        return error_image


def save_to_oci_bucket(file_data, filename, story_title, file_type="image"):
    """Upload bytes to the OCI bucket API; return the remote URL.

    On any failure returns a "local://<filename>" pseudo-URL so callers can
    tell the asset only exists on local disk.
    """
    try:
        api_url = f"{OCI_API_BASE_URL}/api/upload"
        full_subfolder = f'stories/{story_title}'
        mime_type = "image/png" if file_type == "image" else "text/plain"
        files = {'file': (filename, file_data, mime_type)}
        data = {
            'project_id': 'storybook-library',
            'subfolder': full_subfolder
        }
        response = requests.post(api_url, files=files, data=data, timeout=30)
        if response.status_code == 200:
            result = response.json()
            if result['status'] == 'success':
                return result.get('file_url', 'Unknown URL')
            else:
                print(f"⚠️ OCI API Error: {result.get('message', 'Unknown error')}")
                # BUGFIX: fallback was f"local://(unknown)" — an f-string with
                # no placeholder; include the filename in the pseudo-URL.
                return f"local://{filename}"
        else:
            print(f"⚠️ HTTP Error: {response.status_code}")
            return f"local://{filename}"
    except Exception as e:
        print(f"⚠️ OCI upload failed, using local fallback: {str(e)}")
        return f"local://{filename}"


def create_job(story_request: StorybookRequest) -> str:
    """Register a new pending job in the in-memory store; return its id."""
    job_id = str(uuid.uuid4())
    job_storage[job_id] = {
        "status": JobStatus.PENDING,
        "progress": 0,
        "message": "Job created and queued",
        "request": story_request.dict(),
        "result": None,
        "created_at": time.time(),
        "updated_at": time.time(),
    }
    print(f"📝 Created job {job_id} for story: {story_request.story_title}")
    return job_id


def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
    """Update a job's fields; returns False if the job id is unknown."""
    if job_id not in job_storage:
        return False
    job_storage[job_id].update({
        "status": status,
        "progress": progress,
        "message": message,
        "updated_at": time.time()
    })
    # BUGFIX: was `if result:`, which dropped falsy-but-present results.
    if result is not None:
        job_storage[job_id]["result"] = result
    return True


def generate_storybook_background(job_id: str):
    """Background task for storybook generation.

    Generates character images (progress 10-40%), then scene images
    (40-95%), saving each locally and uploading to OCI. Per-item failures
    are recorded in the result instead of aborting the whole job.
    """
    try:
        job_data = job_storage[job_id]
        story_request_data = job_data["request"]
        story_request = StorybookRequest(**story_request_data)

        print(f"🎬 Starting storybook generation for job {job_id}")
        update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting generation...")

        # Generate characters first
        character_urls = {}
        if story_request.characters:
            update_job_status(job_id, JobStatus.PROCESSING, 10, "Generating characters...")

            for i, character in enumerate(story_request.characters):
                progress = 10 + int((i / len(story_request.characters)) * 30)
                update_job_status(job_id, JobStatus.PROCESSING, progress,
                                  f"Generating character: {character.name}")

                try:
                    print(f"👤 Generating character: {character.name}")
                    character_image = generate_character_image(
                        character, story_request.consistency_seed
                    )

                    # Save character locally.
                    # NOTE(review): character.name goes straight into the
                    # filename — names containing path separators would
                    # escape CHARACTERS_DIR; consider sanitizing upstream.
                    char_filename = f"character_{character.name}_{job_id}.png"
                    char_local_path = os.path.join(CHARACTERS_DIR, char_filename)
                    character_image.save(char_local_path, 'PNG')

                    # Upload to OCI
                    img_bytes = io.BytesIO()
                    character_image.save(img_bytes, format='PNG')
                    character_url = save_to_oci_bucket(
                        img_bytes.getvalue(),
                        f"character_{character.name}.png",
                        story_request.story_title,
                        "image"
                    )

                    character_urls[character.name] = {
                        "url": character_url,
                        "local_path": char_local_path
                    }
                    print(f"✅ Character {character.name} completed")

                except Exception as e:
                    error_msg = f"Failed to generate character {character.name}: {str(e)}"
                    print(f"❌ {error_msg}")
                    character_urls[character.name] = {"url": f"error_{character.name}", "local_path": ""}

        # Generate scenes
        update_job_status(job_id, JobStatus.PROCESSING, 40, "Generating scenes...")
        generated_pages = []

        for i, scene in enumerate(story_request.scenes):
            progress = 40 + int((i / len(story_request.scenes)) * 55)
            update_job_status(job_id, JobStatus.PROCESSING, progress,
                              f"Generating scene {i+1}/{len(story_request.scenes)}...")

            try:
                print(f"🖼️ Generating scene {i+1}")

                # Enhanced scene prompt with character context
                character_context = ""
                if scene.characters_present:
                    character_context = f" featuring {', '.join(scene.characters_present)}"

                scene_prompt = f"children's book illustration, {scene.visual}{character_context}, colorful, clean, professional artwork"
                negative_prompt = "blurry, low quality, bad anatomy, dark, scary"

                scene_image = generate_simple_image(
                    scene_prompt, negative_prompt, story_request.consistency_seed
                )

                # Save scene locally
                scene_filename = f"scene_{i+1:03d}_{job_id}.png"
                scene_local_path = os.path.join(PERSISTENT_IMAGE_DIR, scene_filename)
                scene_image.save(scene_local_path, 'PNG')

                # Upload to OCI
                img_bytes = io.BytesIO()
                scene_image.save(img_bytes, format='PNG')
                scene_url = save_to_oci_bucket(
                    img_bytes.getvalue(),
                    f"scene_{i+1:03d}.png",
                    story_request.story_title,
                    "image"
                )

                page_data = {
                    "page_number": i + 1,
                    "image_url": scene_url,
                    "local_path": scene_local_path,
                    "text": scene.text,
                    "characters_present": scene.characters_present
                }
                generated_pages.append(page_data)
                print(f"✅ Scene {i+1} completed")

            except Exception as e:
                error_msg = f"Failed to generate scene {i+1}: {str(e)}"
                print(f"❌ {error_msg}")
                page_data = {
                    "page_number": i + 1,
                    "image_url": f"error_scene_{i+1}",
                    "local_path": "",
                    "text": scene.text,
                    "characters_present": scene.characters_present,
                    "error": error_msg
                }
                generated_pages.append(page_data)

        # Final result
        result = {
            "story_title": story_request.story_title,
            "total_pages": len(generated_pages),
            "total_characters": len(character_urls),
            "characters": character_urls,
            "pages": generated_pages,
            "job_id": job_id,
            "rembg_available": REMBG_AVAILABLE
        }

        update_job_status(
            job_id, JobStatus.COMPLETED, 100,
            f"🎉 Storybook completed! {len(generated_pages)} scenes and {len(character_urls)} characters generated.",
            result
        )
        print(f"🎉 Storybook finished for job {job_id}")

    except Exception as e:
        error_msg = f"Story generation failed: {str(e)}"
        print(f"❌ {error_msg}")
        update_job_status(job_id, JobStatus.FAILED, 0, error_msg)


# API Routes
@app.post("/api/generate-storybook")
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
    """Storybook generation endpoint: validate, queue a job, return its id."""
    try:
        print(f"📥 Received storybook request: {request.get('story_title', 'Unknown')}")

        # Set default seed if not provided
        if 'consistency_seed' not in request or not request['consistency_seed']:
            request['consistency_seed'] = random.randint(1000, 9999)

        story_request = StorybookRequest(**request)

        if not story_request.story_title or not story_request.scenes:
            raise HTTPException(status_code=400, detail="story_title and scenes are required")

        job_id = create_job(story_request)
        background_tasks.add_task(generate_storybook_background, job_id)

        return {
            "status": "success",
            "message": "Storybook generation started",
            "job_id": job_id,
            "story_title": story_request.story_title,
            "total_scenes": len(story_request.scenes),
            "total_characters": len(story_request.characters),
            "consistency_seed": story_request.consistency_seed,
            "rembg_available": REMBG_AVAILABLE
        }

    except Exception as e:
        error_msg = f"API Error: {str(e)}"
        print(f"❌ {error_msg}")
        raise HTTPException(status_code=500, detail=error_msg)


@app.get("/api/job-status/{job_id}")
async def get_job_status(job_id: str):
    """Return the current status/result of a job, or 404 if unknown."""
    job_data = job_storage.get(job_id)
    if not job_data:
        raise HTTPException(status_code=404, detail="Job not found")

    return JobStatusResponse(
        job_id=job_id,
        status=job_data["status"],
        progress=job_data["progress"],
        message=job_data["message"],
        result=job_data["result"],
        created_at=job_data["created_at"],
        updated_at=job_data["updated_at"]
    )


@app.get("/api/health")
async def health_check():
    """Liveness probe with basic service state."""
    return {
        "status": "healthy",
        "service": "storybook-generator",
        "timestamp": datetime.now().isoformat(),
        "active_jobs": len(job_storage),
        "model_loaded": "sd-1.5" in model_cache,
        "rembg_available": REMBG_AVAILABLE
    }


@app.get("/")
async def root():
    return {"message": "Storybook Generator API", "status": "running"}


# Simple Gradio Interface
def create_test_interface():
    """Build a minimal Gradio UI for one-off test image generation."""
    with gr.Blocks(title="Storybook Generator Test") as demo:
        gr.Markdown("# 🎨 Storybook Generator Test")

        with gr.Row():
            with gr.Column():
                test_prompt = gr.Textbox(
                    label="Test Prompt",
                    value="a cute cartoon cat reading a book under a tree",
                    lines=2
                )
                test_seed = gr.Number(label="Seed", value=42)
                generate_btn = gr.Button("Generate Test Image", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Generated Image", height=512)
                status_text = gr.Textbox(label="Status", interactive=False)

        def test_generate(prompt, seed):
            # Returns (image, status message) for the two output components.
            try:
                image = generate_simple_image(prompt, seed=seed)
                return image, "✅ Image generated successfully!"
            except Exception as e:
                error_msg = f"❌ Error: {str(e)}"
                print(error_msg)
                return None, error_msg

        generate_btn.click(
            test_generate,
            inputs=[test_prompt, test_seed],
            outputs=[output_image, status_text]
        )

    return demo


# Initialize the app
print("🚀 Initializing Storybook Generator...")
print(f"📦 rembg available: {REMBG_AVAILABLE}")

try:
    # Test model loading
    load_model("sd-1.5")
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Model loading failed: {e}")

demo = create_test_interface()

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)