# Hugging Face Space: Storybook Generator (FastAPI API + Gradio test UI).
# NOTE: the Space build log reported build errors for this revision.
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler | |
| from PIL import Image | |
| import io | |
| import requests | |
| import os | |
| from datetime import datetime | |
| import time | |
| import json | |
| from typing import List, Optional | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from pydantic import BaseModel | |
| import threading | |
| import uuid | |
| import random | |
| from enum import Enum | |
| import numpy as np | |
# Optional dependency: rembg provides background removal so character art
# can be rendered with a transparent background.
try:
    from rembg import remove
except ImportError:
    REMBG_AVAILABLE = False
    print("β οΈ rembg not available, character transparency disabled")
else:
    REMBG_AVAILABLE = True
# Base URL of the external OCI upload API used by save_to_oci_bucket.
OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"

# Local mirror directories for generated assets (created on import).
PERSISTENT_IMAGE_DIR = "generated_test_images"
CHARACTERS_DIR = "characters"
for _asset_dir in (PERSISTENT_IMAGE_DIR, CHARACTERS_DIR):
    os.makedirs(_asset_dir, exist_ok=True)
print(f"π Created local directories")
# Initialize FastAPI app
app = FastAPI(title="Storybook Generator API")
from fastapi.middleware.cors import CORSMiddleware
# Wide-open CORS: the Space is called cross-origin by external front-ends.
# NOTE(review): browsers ignore allow_credentials=True when origins is "*" —
# confirm whether credentialed requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class JobStatus(str, Enum):
    """Lifecycle states of an async generation job (str-valued for JSON)."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
class StoryScene(BaseModel):
    """One storybook page: an image prompt plus the page's text."""
    visual: str  # scene description fed into the diffusion prompt
    text: str  # story text shown on the page
    characters_present: List[str] = []  # names of characters in this scene
class CharacterDescription(BaseModel):
    """A recurring character; visual_prompt (if set) overrides description
    when building the generation prompt."""
    name: str  # used in filenames and result keys
    description: str  # fallback prompt text
    visual_prompt: str = ""  # preferred prompt text when non-empty
    key_features: List[str] = []  # NOTE(review): not read in this file — confirm
class StorybookRequest(BaseModel):
    """Payload for a full storybook generation job."""
    story_title: str  # also used as the OCI subfolder name
    scenes: List[StoryScene]  # one generated image per scene
    characters: List[CharacterDescription] = []  # optional character sheet
    model_choice: str = "sd-1.5"  # NOTE(review): not read in this file (load_model is hard-coded) — confirm
    style: str = "childrens_book"  # NOTE(review): not read in this file — confirm
    callback_url: Optional[str] = None  # NOTE(review): not used in this file — confirm
    consistency_seed: Optional[int] = None  # shared seed for cross-image consistency
class JobStatusResponse(BaseModel):
    """Shape of the job-status polling response."""
    job_id: str
    status: JobStatus
    progress: int  # 0-100
    message: str  # human-readable progress message
    result: Optional[dict] = None  # final payload once COMPLETED
    created_at: float  # epoch seconds
    updated_at: float  # epoch seconds
# Model configuration - Using smaller model for better compatibility
MODEL_CONFIG = {
    "sd-1.5": {
        "model_id": "runwayml/stable-diffusion-v1-5",
        "revision": "fp16",  # NOTE(review): not passed to from_pretrained by load_model — confirm intent
        "torch_dtype": torch.float16
    }
}

# In-memory, single-process state shared across requests.
job_storage = {}  # job_id -> status/progress/result record
model_cache = {}  # model name -> loaded pipeline
current_pipe = None  # most recently activated pipeline
model_lock = threading.Lock()  # serializes model loading
def load_model(model_name="sd-1.5"):
    """Load and cache a Stable Diffusion pipeline.

    Thread-safe: the entire load path is serialized with ``model_lock``.
    On any failure the function retries once with a float32 load of SD 1.5
    (fp16 weights can fail on CPU-only hosts) before giving up.

    Args:
        model_name: Key into ``MODEL_CONFIG`` (currently only "sd-1.5").

    Returns:
        A ready-to-use ``StableDiffusionPipeline`` on the best device.

    Raises:
        Exception: the primary load error, chained with the fallback
            error, when both attempts fail.
    """
    global model_cache, current_pipe
    with model_lock:
        # Fast path: reuse an already-loaded pipeline.
        if model_name in model_cache:
            current_pipe = model_cache[model_name]
            return current_pipe
        print(f"π Loading model: {model_name}")
        try:
            config = MODEL_CONFIG[model_name]
            # Safety checker disabled: children's-book prompts only, and the
            # checker adds significant memory overhead.
            # BUG FIX: the configured "revision" key was stored but never
            # passed to from_pretrained; forward it so the fp16 weight
            # revision is actually requested.
            pipe = StableDiffusionPipeline.from_pretrained(
                config["model_id"],
                revision=config.get("revision"),
                torch_dtype=config["torch_dtype"],
                safety_checker=None,
                requires_safety_checker=False
            )
            # Euler-Ancestral gives good results at the low step count used here.
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
            # Move to appropriate device
            if torch.cuda.is_available():
                pipe = pipe.to("cuda")
                print("β Using CUDA")
            else:
                pipe = pipe.to("cpu")
                print("β Using CPU")
            pipe.enable_attention_slicing()  # lower peak memory, small speed cost
            model_cache[model_name] = pipe
            current_pipe = pipe
            print(f"β Model loaded successfully: {model_name}")
            return pipe
        except Exception as e:
            print(f"β Model loading failed: {e}")
            try:
                print("π Trying fallback model...")
                # float32 fallback without a pinned revision.
                pipe = StableDiffusionPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    torch_dtype=torch.float32
                )
                pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
                pipe.enable_attention_slicing()
                model_cache[model_name] = pipe
                current_pipe = pipe
                print("β Fallback model loaded successfully")
                return pipe
            except Exception as fallback_error:
                print(f"β Fallback model also failed: {fallback_error}")
                # Surface the original failure, explicitly chaining the
                # fallback error for a complete traceback.
                raise e from fallback_error
def generate_simple_image(prompt, negative_prompt="", seed=None, width=512, height=512):
    """Generate one image, never raising: failures yield a red placeholder.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Concepts to steer away from.
        seed: Optional seed for reproducible output; ``None`` means random.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A PIL image; a solid red image of the requested size on failure.
    """
    try:
        pipe = load_model("sd-1.5")
        if pipe is None:
            raise Exception("Model not available")
        generator = None
        # BUG FIX: was `if seed:`, which silently ignored seed=0; only skip
        # seeding when no seed was supplied at all.  int() also accepts the
        # float that Gradio's Number component passes in.
        if seed is not None:
            generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
        # autocast is unsupported for MPS here, so fall back to CPU autocast.
        with torch.autocast(pipe.device.type if pipe.device.type != 'mps' else 'cpu'):
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=20,
                guidance_scale=7.5,
                width=width,
                height=height,
                generator=generator
            )
        return result.images[0]
    except Exception as e:
        print(f"β Image generation failed: {e}")
        # Callers always receive a displayable image.
        return Image.new('RGB', (width, height), color='red')
def generate_character_image(character_desc, seed=None):
    """Render one character on a plain background and return an RGBA image.

    Uses visual_prompt when set, otherwise description; removes the
    background via rembg when available.  Never raises: failures yield a
    translucent red 512x512 placeholder.
    """
    try:
        base_prompt = character_desc.visual_prompt or character_desc.description
        full_prompt = f"{base_prompt}, character design, clean lines, isolated on plain background, cartoon style, children's book illustration"
        neg_prompt = "blurry, low quality, complex background, multiple characters, dark, scary"
        rendered = generate_simple_image(
            full_prompt,
            neg_prompt,
            seed,
            width=512,
            height=512
        )
        if not REMBG_AVAILABLE:
            # No rembg: just ensure an alpha channel exists.
            return rendered.convert('RGBA')
        try:
            rendered = remove(rendered)
        except Exception as bg_error:
            print(f"β οΈ Background removal failed: {bg_error}")
            rendered = rendered.convert('RGBA')
        return rendered
    except Exception as e:
        print(f"β Character generation failed: {e}")
        return Image.new('RGBA', (512, 512), (255, 0, 0, 128))
def save_to_oci_bucket(file_data, filename, story_title, file_type="image"):
    """Upload a file to the OCI bucket through the external upload API.

    Args:
        file_data: Raw bytes to upload.
        filename: Object name for the stored file.
        story_title: Per-story subfolder name under ``stories/``.
        file_type: "image" uploads as image/png, anything else as text/plain.

    Returns:
        The remote file URL on success, or a ``local://<filename>`` marker
        on any failure (callers treat that as a soft failure and keep the
        locally saved copy).
    """
    # BUG FIX: the fallback previously returned the literal placeholder
    # f"local://(unknown)"; include the actual filename instead.
    fallback_url = f"local://{filename}"
    try:
        api_url = f"{OCI_API_BASE_URL}/api/upload"
        full_subfolder = f'stories/{story_title}'
        mime_type = "image/png" if file_type == "image" else "text/plain"
        files = {'file': (filename, file_data, mime_type)}
        data = {
            'project_id': 'storybook-library',
            'subfolder': full_subfolder
        }
        response = requests.post(api_url, files=files, data=data, timeout=30)
        if response.status_code == 200:
            result = response.json()
            # .get avoids a KeyError on malformed success responses.
            if result.get('status') == 'success':
                return result.get('file_url', 'Unknown URL')
            print(f"β οΈ OCI API Error: {result.get('message', 'Unknown error')}")
            return fallback_url
        else:
            print(f"β οΈ HTTP Error: {response.status_code}")
            return fallback_url
    except Exception as e:
        print(f"β οΈ OCI upload failed, using local fallback: {str(e)}")
        return fallback_url
def create_job(story_request: StorybookRequest) -> str:
    """Register a new PENDING job for *story_request* and return its id."""
    job_id = str(uuid.uuid4())
    record = {
        "status": JobStatus.PENDING,
        "progress": 0,
        "message": "Job created and queued",
        "request": story_request.dict(),  # stored as a plain dict for re-hydration
        "result": None,
        "created_at": time.time(),
        "updated_at": time.time(),
    }
    job_storage[job_id] = record
    print(f"π Created job {job_id} for story: {story_request.story_title}")
    return job_id
def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
    """Update a job record in place.

    Args:
        job_id: Job to update; unknown ids are ignored.
        status: New lifecycle state.
        progress: Percentage 0-100.
        message: Human-readable progress message.
        result: Optional result payload, stored only when provided.

    Returns:
        True if the job existed and was updated, False otherwise.
    """
    record = job_storage.get(job_id)
    if record is None:
        return False
    record.update({
        "status": status,
        "progress": progress,
        "message": message,
        "updated_at": time.time()
    })
    # BUG FIX: was `if result:`, which discarded falsy-but-meaningful
    # payloads such as an empty dict; only None means "no result".
    if result is not None:
        record["result"] = result
    return True
def generate_storybook_background(job_id: str) -> None:
    """Background task for storybook generation.

    Drives one job end to end: re-hydrates the stored request, renders
    every character, then every scene, saving each image locally and
    uploading it to the OCI bucket.  Progress is reported via
    update_job_status (5% start, 10-40% characters, 40-95% scenes,
    100% done).  Per-item failures are recorded in the result rather
    than aborting the job; only an unexpected top-level error marks the
    job FAILED.
    """
    try:
        # Re-hydrate the dict stored by create_job back into a model.
        job_data = job_storage[job_id]
        story_request_data = job_data["request"]
        story_request = StorybookRequest(**story_request_data)
        print(f"π¬ Starting storybook generation for job {job_id}")
        update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting generation...")
        # Generate characters first
        character_urls = {}  # character name -> {"url", "local_path"}
        if story_request.characters:
            update_job_status(job_id, JobStatus.PROCESSING, 10, "Generating characters...")
            for i, character in enumerate(story_request.characters):
                # Characters occupy the 10-40% progress window.
                progress = 10 + int((i / len(story_request.characters)) * 30)
                update_job_status(job_id, JobStatus.PROCESSING, progress, f"Generating character: {character.name}")
                try:
                    print(f"π€ Generating character: {character.name}")
                    # One shared seed keeps character/scene style coherent.
                    character_image = generate_character_image(
                        character,
                        story_request.consistency_seed
                    )
                    # Save character locally
                    char_filename = f"character_{character.name}_{job_id}.png"
                    char_local_path = os.path.join(CHARACTERS_DIR, char_filename)
                    character_image.save(char_local_path, 'PNG')
                    # Upload to OCI
                    img_bytes = io.BytesIO()
                    character_image.save(img_bytes, format='PNG')
                    character_url = save_to_oci_bucket(
                        img_bytes.getvalue(),
                        f"character_{character.name}.png",
                        story_request.story_title,
                        "image"
                    )
                    character_urls[character.name] = {
                        "url": character_url,
                        "local_path": char_local_path
                    }
                    print(f"β Character {character.name} completed")
                except Exception as e:
                    # Record a sentinel URL so the rest of the job continues.
                    error_msg = f"Failed to generate character {character.name}: {str(e)}"
                    print(f"β {error_msg}")
                    character_urls[character.name] = {"url": f"error_{character.name}", "local_path": ""}
        # Generate scenes
        update_job_status(job_id, JobStatus.PROCESSING, 40, "Generating scenes...")
        generated_pages = []
        for i, scene in enumerate(story_request.scenes):
            # Scenes occupy the 40-95% progress window.
            progress = 40 + int((i / len(story_request.scenes)) * 55)
            update_job_status(job_id, JobStatus.PROCESSING, progress, f"Generating scene {i+1}/{len(story_request.scenes)}...")
            try:
                print(f"πΌοΈ Generating scene {i+1}")
                # Enhanced scene prompt with character context
                character_context = ""
                if scene.characters_present:
                    character_context = f" featuring {', '.join(scene.characters_present)}"
                scene_prompt = f"children's book illustration, {scene.visual}{character_context}, colorful, clean, professional artwork"
                negative_prompt = "blurry, low quality, bad anatomy, dark, scary"
                scene_image = generate_simple_image(
                    scene_prompt,
                    negative_prompt,
                    story_request.consistency_seed
                )
                # Save scene locally
                scene_filename = f"scene_{i+1:03d}_{job_id}.png"
                scene_local_path = os.path.join(PERSISTENT_IMAGE_DIR, scene_filename)
                scene_image.save(scene_local_path, 'PNG')
                # Upload to OCI
                img_bytes = io.BytesIO()
                scene_image.save(img_bytes, format='PNG')
                scene_url = save_to_oci_bucket(
                    img_bytes.getvalue(),
                    f"scene_{i+1:03d}.png",
                    story_request.story_title,
                    "image"
                )
                page_data = {
                    "page_number": i + 1,
                    "image_url": scene_url,
                    "local_path": scene_local_path,
                    "text": scene.text,
                    "characters_present": scene.characters_present
                }
                generated_pages.append(page_data)
                print(f"β Scene {i+1} completed")
            except Exception as e:
                # Failed scenes still produce a page entry (with an "error"
                # field) so page numbering stays contiguous.
                error_msg = f"Failed to generate scene {i+1}: {str(e)}"
                print(f"β {error_msg}")
                page_data = {
                    "page_number": i + 1,
                    "image_url": f"error_scene_{i+1}",
                    "local_path": "",
                    "text": scene.text,
                    "characters_present": scene.characters_present,
                    "error": error_msg
                }
                generated_pages.append(page_data)
        # Final result
        result = {
            "story_title": story_request.story_title,
            "total_pages": len(generated_pages),
            "total_characters": len(character_urls),
            "characters": character_urls,
            "pages": generated_pages,
            "job_id": job_id,
            "rembg_available": REMBG_AVAILABLE
        }
        update_job_status(
            job_id,
            JobStatus.COMPLETED,
            100,
            f"π Storybook completed! {len(generated_pages)} scenes and {len(character_urls)} characters generated.",
            result
        )
        print(f"π Storybook finished for job {job_id}")
    except Exception as e:
        error_msg = f"Story generation failed: {str(e)}"
        print(f"β {error_msg}")
        update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
# API Routes
# NOTE(review): no route decorators (e.g. @app.post(...)) are attached to
# these handlers in this file — they appear to have been lost; confirm how
# they are registered with `app`.
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
    """Storybook generation endpoint.

    Validates the raw request dict, registers a job, and schedules the
    generation as a background task; returns immediately with the job id
    so clients can poll for progress.
    """
    try:
        print(f"π₯ Received storybook request: {request.get('story_title', 'Unknown')}")
        # Set default seed if not provided (a falsy seed such as 0 is
        # also replaced with a random one, matching generate_simple_image's
        # original treatment of 0 as "unseeded").
        if 'consistency_seed' not in request or not request['consistency_seed']:
            request['consistency_seed'] = random.randint(1000, 9999)
        story_request = StorybookRequest(**request)
        if not story_request.story_title or not story_request.scenes:
            raise HTTPException(status_code=400, detail="story_title and scenes are required")
        job_id = create_job(story_request)
        background_tasks.add_task(generate_storybook_background, job_id)
        return {
            "status": "success",
            "message": "Storybook generation started",
            "job_id": job_id,
            "story_title": story_request.story_title,
            "total_scenes": len(story_request.scenes),
            "total_characters": len(story_request.characters),
            "consistency_seed": story_request.consistency_seed,
            "rembg_available": REMBG_AVAILABLE
        }
    # BUG FIX: the broad handler below previously swallowed the 400
    # validation error above and re-raised it as a 500; let HTTPExceptions
    # propagate with their original status code.
    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"API Error: {str(e)}"
        print(f"β {error_msg}")
        raise HTTPException(status_code=500, detail=error_msg)
async def get_job_status(job_id: str):
    """Look up a job and return its status record (404 when unknown)."""
    record = job_storage.get(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found")
    # Copy the stored record fields straight into the response model.
    field_names = ("status", "progress", "message", "result", "created_at", "updated_at")
    return JobStatusResponse(job_id=job_id, **{name: record[name] for name in field_names})
async def health_check():
    """Report service liveness plus coarse runtime state for monitoring."""
    payload = {
        "status": "healthy",
        "service": "storybook-generator",
        "timestamp": datetime.now().isoformat(),
        "active_jobs": len(job_storage),
        "rembg_available": REMBG_AVAILABLE,
    }
    payload["model_loaded"] = "sd-1.5" in model_cache
    return payload
async def root():
    """Root endpoint: confirms the API process is up."""
    banner = {"message": "Storybook Generator API", "status": "running"}
    return banner
# Simple Gradio Interface
def create_test_interface():
    """Build a minimal Gradio UI for smoke-testing single-image generation.

    Returns:
        The assembled ``gr.Blocks`` demo (not launched here).
    """
    with gr.Blocks(title="Storybook Generator Test") as demo:
        gr.Markdown("# π¨ Storybook Generator Test")
        with gr.Row():
            with gr.Column():
                test_prompt = gr.Textbox(
                    label="Test Prompt",
                    value="a cute cartoon cat reading a book under a tree",
                    lines=2
                )
                test_seed = gr.Number(label="Seed", value=42)
                generate_btn = gr.Button("Generate Test Image", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Generated Image", height=512)
                status_text = gr.Textbox(label="Status", interactive=False)

        def test_generate(prompt, seed):
            """Click handler: returns (image, status message)."""
            # BUG FIX: the original rebound `status_text` here, shadowing
            # the Gradio component from the enclosing scope; use a distinct
            # local name for the message.
            try:
                message = "π Generating image..."
                image = generate_simple_image(prompt, seed=seed)
                message = "β Image generated successfully!"
                return image, message
            except Exception as e:
                error_msg = f"β Error: {str(e)}"
                print(error_msg)
                return None, error_msg

        generate_btn.click(
            test_generate,
            inputs=[test_prompt, test_seed],
            outputs=[output_image, status_text]
        )
    return demo
# Initialize the app: warm the model at import time so the first request
# does not pay the full load cost, then build the Gradio test UI.
print("π Initializing Storybook Generator...")
print(f"π¦ rembg available: {REMBG_AVAILABLE}")
try:
    load_model("sd-1.5")
except Exception as e:
    # Non-fatal: generation calls will retry loading on demand.
    print(f"β Model loading failed: {e}")
else:
    print("β Model loaded successfully!")
demo = create_test_interface()
if __name__ == "__main__":
    import uvicorn
    # BUG FIX: `demo` was built above but never served anywhere; mount the
    # Gradio test UI onto the FastAPI app so both the API routes and the UI
    # are reachable from one server.
    app = gr.mount_gradio_app(app, demo, path="/ui")
    uvicorn.run(app, host="0.0.0.0", port=8000)