# Storybook Generator — Hugging Face Space app (FastAPI API + Gradio test UI + Stable Diffusion)
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
import io
import requests
import os
from datetime import datetime
import time
import json
from typing import List, Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
import threading
import uuid
import random
from enum import Enum
import numpy as np
# Optional dependency: rembg enables transparent character backgrounds.
# When it is missing, characters are kept on their generated background.
try:
    from rembg import remove
except ImportError:
    REMBG_AVAILABLE = False
    print("⚠️ rembg not available, character transparency disabled")
else:
    REMBG_AVAILABLE = True
# External OCI API endpoint used for uploading generated assets.
OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"

# Local output directories, created up front so later saves never hit a missing dir.
PERSISTENT_IMAGE_DIR = "generated_test_images"
CHARACTERS_DIR = "characters"
for _dir in (PERSISTENT_IMAGE_DIR, CHARACTERS_DIR):
    os.makedirs(_dir, exist_ok=True)
print(f"📁 Created local directories")
# Initialize the FastAPI application.
app = FastAPI(title="Storybook Generator API")

from fastapi.middleware.cors import CORSMiddleware

# Fully permissive CORS so browser clients on any origin can call the API.
# NOTE(review): wide open ("*" + credentials); acceptable for a public demo space.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
class JobStatus(str, Enum):
    """Lifecycle states of a background generation job (str-valued so it serializes directly to JSON)."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
class StoryScene(BaseModel):
    """One page of the story: an illustration prompt plus the page's caption text."""

    visual: str  # prompt text describing the scene image
    text: str  # narrative text shown on the page
    characters_present: List[str] = []  # character names woven into the scene prompt (pydantic copies list defaults)
class CharacterDescription(BaseModel):
    """A named story character used for standalone character renders."""

    name: str
    description: str
    visual_prompt: str = ""  # preferred generation prompt; falls back to description when empty
    key_features: List[str] = []  # NOTE(review): accepted but not read anywhere in this file
class StorybookRequest(BaseModel):
    """Payload for POST /api/generate-storybook."""

    story_title: str
    scenes: List[StoryScene]
    characters: List[CharacterDescription] = []
    model_choice: str = "sd-1.5"  # NOTE(review): currently ignored — generation always loads "sd-1.5"
    style: str = "childrens_book"  # NOTE(review): not read anywhere in this file
    callback_url: Optional[str] = None  # NOTE(review): accepted but never called back
    consistency_seed: Optional[int] = None  # shared seed so characters/scenes render reproducibly
class JobStatusResponse(BaseModel):
    """Response body for GET /api/job-status/{job_id}."""

    job_id: str
    status: JobStatus
    progress: int  # 0-100
    message: str
    result: Optional[dict] = None  # final payload once status == COMPLETED
    created_at: float  # unix timestamp (time.time())
    updated_at: float  # unix timestamp of last status change
# Model registry — only SD 1.5 is offered; it is small enough for modest hardware.
MODEL_CONFIG = {
    "sd-1.5": {
        "model_id": "runwayml/stable-diffusion-v1-5",
        "revision": "fp16",  # NOTE(review): not forwarded to from_pretrained below — confirm intent
        "torch_dtype": torch.float16,
    }
}

# In-memory shared state (single-process only; lost on restart).
job_storage = {}  # job_id -> job record dict
model_cache = {}  # model name -> loaded pipeline
current_pipe = None  # most recently loaded/used pipeline
model_lock = threading.Lock()  # serializes model loading across requests
def load_model(model_name="sd-1.5"):
    """Load (or fetch from cache) a Stable Diffusion pipeline.

    Thread-safe via ``model_lock``. On the first call the model is downloaded
    from the Hugging Face hub (slow, network-bound). If fp16 loading fails,
    a float32 fallback of SD 1.5 is attempted before giving up.

    Returns the pipeline; raises the original exception if both attempts fail.
    """
    global model_cache, current_pipe
    with model_lock:
        # Fast path: reuse the already-loaded pipeline.
        if model_name in model_cache:
            current_pipe = model_cache[model_name]
            return current_pipe
        print(f"🔄 Loading model: {model_name}")
        try:
            config = MODEL_CONFIG[model_name]
            # Use simpler loading; safety checker disabled for children's-book content.
            # NOTE(review): config["revision"] is never passed here — confirm intent.
            pipe = StableDiffusionPipeline.from_pretrained(
                config["model_id"],
                torch_dtype=config["torch_dtype"],
                safety_checker=None,
                requires_safety_checker=False
            )
            # Swap in an Euler-Ancestral scheduler (keeps the model's own config).
            pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
            # Move to GPU when available, otherwise run on CPU.
            if torch.cuda.is_available():
                pipe = pipe.to("cuda")
                print("✅ Using CUDA")
            else:
                pipe = pipe.to("cpu")
                print("✅ Using CPU")
            # Reduce peak memory during attention at a small speed cost.
            pipe.enable_attention_slicing()
            model_cache[model_name] = pipe
            current_pipe = pipe
            print(f"✅ Model loaded successfully: {model_name}")
            return pipe
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            # Fallback: same model in float32 (more compatible, more memory).
            try:
                print("🔄 Trying fallback model...")
                pipe = StableDiffusionPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    torch_dtype=torch.float32
                )
                pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
                pipe.enable_attention_slicing()
                # Cached under the requested name, so later calls reuse the fallback.
                model_cache[model_name] = pipe
                current_pipe = pipe
                print("✅ Fallback model loaded successfully")
                return pipe
            except Exception as fallback_error:
                print(f"❌ Fallback model also failed: {fallback_error}")
                # Surface the original (primary) failure to the caller.
                raise e
def generate_simple_image(prompt, negative_prompt="", seed=None, width=512, height=512):
    """Generate one image with the cached SD 1.5 pipeline.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Qualities to steer away from.
        seed: Optional int for reproducible output; None means nondeterministic.
        width, height: Output size in pixels.

    Returns:
        A PIL image. On any failure a solid red placeholder of the requested
        size is returned instead of raising, so callers never crash.
    """
    try:
        pipe = load_model("sd-1.5")
        if pipe is None:
            raise Exception("Model not available")
        generator = None
        # Fix: compare against None so seed=0 is honored (was `if seed:`).
        if seed is not None:
            generator = torch.Generator(device=pipe.device).manual_seed(seed)
        # autocast has no MPS backend here, so fall back to cpu autocast on Apple silicon.
        with torch.autocast(pipe.device.type if pipe.device.type != 'mps' else 'cpu'):
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=20,
                guidance_scale=7.5,
                width=width,
                height=height,
                generator=generator
            )
        return result.images[0]
    except Exception as e:
        print(f"❌ Image generation failed: {e}")
        # Deliberate best-effort: return a visible placeholder instead of raising.
        error_image = Image.new('RGB', (width, height), color='red')
        return error_image
def generate_character_image(character_desc, seed=None):
    """Render a single character on a plain background, returning RGBA when possible.

    Uses ``visual_prompt`` if set, otherwise ``description``. When rembg is
    available the background is removed; otherwise (or if removal fails) the
    image is converted to RGBA as-is. On total failure a translucent red
    512x512 placeholder is returned instead of raising.
    """
    try:
        base = character_desc.visual_prompt or character_desc.description
        full_prompt = (
            f"{base}, character design, clean lines, "
            "isolated on plain background, cartoon style, children's book illustration"
        )
        negatives = "blurry, low quality, complex background, multiple characters, dark, scary"
        rendered = generate_simple_image(
            full_prompt,
            negatives,
            seed,
            width=512,
            height=512
        )
        if not REMBG_AVAILABLE:
            return rendered.convert('RGBA')
        try:
            return remove(rendered)
        except Exception as bg_error:
            print(f"⚠️ Background removal failed: {bg_error}")
            return rendered.convert('RGBA')
    except Exception as e:
        print(f"❌ Character generation failed: {e}")
        return Image.new('RGBA', (512, 512), (255, 0, 0, 128))
def save_to_oci_bucket(file_data, filename, story_title, file_type="image"):
    """Upload a file to the OCI bucket through the external upload API.

    Args:
        file_data: Raw bytes of the file.
        filename: Name to store the file under.
        story_title: Used as the subfolder name under 'stories/'.
        file_type: "image" posts as image/png, anything else as text/plain.

    Returns:
        The uploaded file's URL on success, or the sentinel
        "local://(unknown)" when the upload fails for any reason
        (the caller retains its local copy).
    """
    fallback_url = "local://(unknown)"  # sentinel: file only exists locally
    try:
        api_url = f"{OCI_API_BASE_URL}/api/upload"
        mime_type = "image/png" if file_type == "image" else "text/plain"
        files = {'file': (filename, file_data, mime_type)}
        data = {
            'project_id': 'storybook-library',
            'subfolder': f'stories/{story_title}',
        }
        response = requests.post(api_url, files=files, data=data, timeout=30)
        if response.status_code == 200:
            result = response.json()
            # Fix: use .get so a payload without 'status' is reported as an
            # API error here rather than escaping as a KeyError.
            if result.get('status') == 'success':
                return result.get('file_url', 'Unknown URL')
            print(f"⚠️ OCI API Error: {result.get('message', 'Unknown error')}")
            return fallback_url
        print(f"⚠️ HTTP Error: {response.status_code}")
        return fallback_url
    except Exception as e:
        print(f"⚠️ OCI upload failed, using local fallback: {str(e)}")
        return fallback_url
def create_job(story_request: StorybookRequest) -> str:
    """Register a new PENDING job for *story_request* and return its id."""
    job_id = str(uuid.uuid4())
    job_storage[job_id] = dict(
        status=JobStatus.PENDING,
        progress=0,
        message="Job created and queued",
        request=story_request.dict(),  # stored as a plain dict for later re-validation
        result=None,
        created_at=time.time(),
        updated_at=time.time(),
    )
    print(f"📝 Created job {job_id} for story: {story_request.story_title}")
    return job_id
def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
    """Update a job's status record in place.

    Args:
        job_id: Id returned by create_job.
        status: New lifecycle state.
        progress: Percentage 0-100.
        message: Human-readable status line.
        result: Final result payload; only written when provided so routine
            progress updates never clobber a stored result.

    Returns:
        True if the job exists and was updated, False for an unknown job_id.
    """
    job = job_storage.get(job_id)
    if job is None:
        return False
    job.update({
        "status": status,
        "progress": progress,
        "message": message,
        "updated_at": time.time()
    })
    # Fix: compare against None so a falsy-but-present result (e.g. {}) is stored.
    if result is not None:
        job["result"] = result
    return True
def generate_storybook_background(job_id: str):
    """Generate all character and scene images for a queued job.

    Runs as a FastAPI background task. Re-validates the request dict stored
    in ``job_storage[job_id]``, renders each character (10-40% of progress),
    then each scene (40-95%), saving every image locally and uploading it via
    ``save_to_oci_bucket``. Per-item failures are recorded in the result
    instead of aborting; only an unexpected top-level error marks the job
    FAILED. The final result dict is attached via ``update_job_status``.
    """
    try:
        job_data = job_storage[job_id]
        story_request_data = job_data["request"]
        # Re-validate the stored plain dict back into a typed request object.
        story_request = StorybookRequest(**story_request_data)
        print(f"🎬 Starting storybook generation for job {job_id}")
        update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting generation...")
        # Generate characters first so scenes can reference them by name.
        character_urls = {}
        if story_request.characters:
            update_job_status(job_id, JobStatus.PROCESSING, 10, "Generating characters...")
            for i, character in enumerate(story_request.characters):
                # Characters occupy the 10-40% band of the progress bar.
                progress = 10 + int((i / len(story_request.characters)) * 30)
                update_job_status(job_id, JobStatus.PROCESSING, progress, f"Generating character: {character.name}")
                try:
                    print(f"👤 Generating character: {character.name}")
                    character_image = generate_character_image(
                        character,
                        story_request.consistency_seed
                    )
                    # Save character locally.
                    char_filename = f"character_{character.name}_{job_id}.png"
                    char_local_path = os.path.join(CHARACTERS_DIR, char_filename)
                    character_image.save(char_local_path, 'PNG')
                    # Upload a second in-memory PNG copy to OCI.
                    img_bytes = io.BytesIO()
                    character_image.save(img_bytes, format='PNG')
                    character_url = save_to_oci_bucket(
                        img_bytes.getvalue(),
                        f"character_{character.name}.png",
                        story_request.story_title,
                        "image"
                    )
                    character_urls[character.name] = {
                        "url": character_url,
                        "local_path": char_local_path
                    }
                    print(f"✅ Character {character.name} completed")
                except Exception as e:
                    # A failed character is recorded but does not abort the job.
                    error_msg = f"Failed to generate character {character.name}: {str(e)}"
                    print(f"❌ {error_msg}")
                    character_urls[character.name] = {"url": f"error_{character.name}", "local_path": ""}
        # Generate scenes.
        update_job_status(job_id, JobStatus.PROCESSING, 40, "Generating scenes...")
        generated_pages = []
        for i, scene in enumerate(story_request.scenes):
            # Scenes occupy the 40-95% band of the progress bar.
            progress = 40 + int((i / len(story_request.scenes)) * 55)
            update_job_status(job_id, JobStatus.PROCESSING, progress, f"Generating scene {i+1}/{len(story_request.scenes)}...")
            try:
                print(f"🖼️ Generating scene {i+1}")
                # Enrich the scene prompt with the names of present characters.
                character_context = ""
                if scene.characters_present:
                    character_context = f" featuring {', '.join(scene.characters_present)}"
                scene_prompt = f"children's book illustration, {scene.visual}{character_context}, colorful, clean, professional artwork"
                negative_prompt = "blurry, low quality, bad anatomy, dark, scary"
                # Same consistency_seed for every scene keeps styling coherent.
                scene_image = generate_simple_image(
                    scene_prompt,
                    negative_prompt,
                    story_request.consistency_seed
                )
                # Save scene locally.
                scene_filename = f"scene_{i+1:03d}_{job_id}.png"
                scene_local_path = os.path.join(PERSISTENT_IMAGE_DIR, scene_filename)
                scene_image.save(scene_local_path, 'PNG')
                # Upload to OCI.
                img_bytes = io.BytesIO()
                scene_image.save(img_bytes, format='PNG')
                scene_url = save_to_oci_bucket(
                    img_bytes.getvalue(),
                    f"scene_{i+1:03d}.png",
                    story_request.story_title,
                    "image"
                )
                page_data = {
                    "page_number": i + 1,
                    "image_url": scene_url,
                    "local_path": scene_local_path,
                    "text": scene.text,
                    "characters_present": scene.characters_present
                }
                generated_pages.append(page_data)
                print(f"✅ Scene {i+1} completed")
            except Exception as e:
                # A failed scene still produces a page entry carrying the error.
                error_msg = f"Failed to generate scene {i+1}: {str(e)}"
                print(f"❌ {error_msg}")
                page_data = {
                    "page_number": i + 1,
                    "image_url": f"error_scene_{i+1}",
                    "local_path": "",
                    "text": scene.text,
                    "characters_present": scene.characters_present,
                    "error": error_msg
                }
                generated_pages.append(page_data)
        # Assemble and store the final result payload.
        result = {
            "story_title": story_request.story_title,
            "total_pages": len(generated_pages),
            "total_characters": len(character_urls),
            "characters": character_urls,
            "pages": generated_pages,
            "job_id": job_id,
            "rembg_available": REMBG_AVAILABLE
        }
        update_job_status(
            job_id,
            JobStatus.COMPLETED,
            100,
            f"🎉 Storybook completed! {len(generated_pages)} scenes and {len(character_urls)} characters generated.",
            result
        )
        print(f"🎉 Storybook finished for job {job_id}")
    except Exception as e:
        # Top-level failure (bad stored request, storage error, ...): mark FAILED.
        error_msg = f"Story generation failed: {str(e)}"
        print(f"❌ {error_msg}")
        update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
# API Routes
@app.post("/api/generate-storybook")
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
    """Validate the request, queue a background generation job, return its id.

    Accepts a raw dict so a default consistency_seed can be injected before
    pydantic validation. Raises 400 when story_title/scenes are missing and
    500 for unexpected errors.
    """
    try:
        print(f"📥 Received storybook request: {request.get('story_title', 'Unknown')}")
        # Inject a default seed when absent or falsy so all images share one seed.
        if 'consistency_seed' not in request or not request['consistency_seed']:
            request['consistency_seed'] = random.randint(1000, 9999)
        story_request = StorybookRequest(**request)
        if not story_request.story_title or not story_request.scenes:
            raise HTTPException(status_code=400, detail="story_title and scenes are required")
        job_id = create_job(story_request)
        background_tasks.add_task(generate_storybook_background, job_id)
        return {
            "status": "success",
            "message": "Storybook generation started",
            "job_id": job_id,
            "story_title": story_request.story_title,
            "total_scenes": len(story_request.scenes),
            "total_characters": len(story_request.characters),
            "consistency_seed": story_request.consistency_seed,
            "rembg_available": REMBG_AVAILABLE
        }
    except HTTPException:
        # Fix: re-raise deliberate HTTP errors (e.g. the 400 above) instead of
        # letting the blanket handler rewrap them as a 500.
        raise
    except Exception as e:
        error_msg = f"API Error: {str(e)}"
        print(f"❌ {error_msg}")
        raise HTTPException(status_code=500, detail=error_msg)
@app.get("/api/job-status/{job_id}")
async def get_job_status(job_id: str):
    """Return the stored status record for a job, or 404 if it is unknown."""
    job = job_storage.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return JobStatusResponse(
        job_id=job_id,
        status=job["status"],
        progress=job["progress"],
        message=job["message"],
        result=job["result"],
        created_at=job["created_at"],
        updated_at=job["updated_at"]
    )
@app.get("/api/health")
async def health_check():
    """Health probe: liveness plus a few readiness indicators."""
    payload = {
        "status": "healthy",
        "service": "storybook-generator",
        "timestamp": datetime.now().isoformat(),
        "active_jobs": len(job_storage),  # total jobs ever created this process
        "model_loaded": "sd-1.5" in model_cache,
        "rembg_available": REMBG_AVAILABLE,
    }
    return payload
@app.get("/")
async def root():
    """Service banner for the API root."""
    return {"message": "Storybook Generator API", "status": "running"}
# Simple Gradio Interface
def create_test_interface():
    """Build the minimal Gradio UI used to smoke-test image generation."""
    with gr.Blocks(title="Storybook Generator Test") as demo:
        gr.Markdown("# 🎨 Storybook Generator Test")
        with gr.Row():
            with gr.Column():
                test_prompt = gr.Textbox(
                    label="Test Prompt",
                    value="a cute cartoon cat reading a book under a tree",
                    lines=2
                )
                test_seed = gr.Number(label="Seed", value=42)
                generate_btn = gr.Button("Generate Test Image", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Generated Image", height=512)
                status_text = gr.Textbox(label="Status", interactive=False)

        def test_generate(prompt, seed):
            # Returns (image, status message); never raises into the UI.
            try:
                image = generate_simple_image(prompt, seed=seed)
                return image, "✅ Image generated successfully!"
            except Exception as e:
                error_msg = f"❌ Error: {str(e)}"
                print(error_msg)
                return None, error_msg

        generate_btn.click(
            test_generate,
            inputs=[test_prompt, test_seed],
            outputs=[output_image, status_text]
        )
    return demo
# Initialize the app: warm the model cache at import time so the first
# request does not pay the download/load cost.
print("🚀 Initializing Storybook Generator...")
print(f"📦 rembg available: {REMBG_AVAILABLE}")
try:
    load_model("sd-1.5")
except Exception as e:
    print(f"❌ Model loading failed: {e}")
else:
    print("✅ Model loaded successfully!")

# NOTE(review): `demo` is built but never mounted or launched — only the
# FastAPI app is served below; confirm whether the Gradio UI should be mounted.
demo = create_test_interface()
if __name__ == "__main__":
    # Serve the FastAPI app directly (the Gradio `demo` is not mounted here).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)