Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,8 @@ import re
|
|
| 10 |
import time
|
| 11 |
import json
|
| 12 |
from typing import List, Optional, Dict
|
| 13 |
-
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
|
|
| 14 |
from pydantic import BaseModel
|
| 15 |
import gc
|
| 16 |
import psutil
|
|
@@ -23,7 +24,54 @@ import time
|
|
| 23 |
from requests.adapters import HTTPAdapter
|
| 24 |
from urllib3.util.retry import Retry
|
| 25 |
from huggingface_hub import HfApi
|
| 26 |
-
import accelerate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# =============================================
|
| 29 |
# MEMORY OPTIMIZATION SETTINGS
|
|
@@ -57,19 +105,6 @@ PERSISTENT_IMAGE_DIR = "generated_test_images"
|
|
| 57 |
os.makedirs(PERSISTENT_IMAGE_DIR, exist_ok=True)
|
| 58 |
print(f"π Created local image directory: {PERSISTENT_IMAGE_DIR}")
|
| 59 |
|
| 60 |
-
# Initialize FastAPI app
|
| 61 |
-
app = FastAPI(title="Storybook Generator API")
|
| 62 |
-
|
| 63 |
-
# Add CORS middleware
|
| 64 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 65 |
-
app.add_middleware(
|
| 66 |
-
CORSMiddleware,
|
| 67 |
-
allow_origins=["*"],
|
| 68 |
-
allow_credentials=True,
|
| 69 |
-
allow_methods=["*"],
|
| 70 |
-
allow_headers=["*"],
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
# Job Status Enum
|
| 74 |
class JobStatus(str, Enum):
|
| 75 |
PENDING = "pending"
|
|
@@ -140,6 +175,8 @@ model_cache = {}
|
|
| 140 |
current_model_name = None
|
| 141 |
current_pipe = None
|
| 142 |
model_lock = threading.Lock()
|
|
|
|
|
|
|
| 143 |
|
| 144 |
# MEMORY MANAGEMENT FUNCTIONS
|
| 145 |
def get_memory_usage():
|
|
@@ -230,11 +267,11 @@ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False,
|
|
| 230 |
}
|
| 231 |
|
| 232 |
# =============================================
|
| 233 |
-
#
|
| 234 |
# =============================================
|
| 235 |
def load_model(model_name="dreamshaper-8"):
|
| 236 |
-
"""Thread-safe model loading with
|
| 237 |
-
global model_cache, current_model_name, current_pipe
|
| 238 |
|
| 239 |
with model_lock:
|
| 240 |
if model_name in model_cache:
|
|
@@ -242,26 +279,32 @@ def load_model(model_name="dreamshaper-8"):
|
|
| 242 |
current_model_name = model_name
|
| 243 |
return current_pipe
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
print(f"π Loading model: {model_name}")
|
| 246 |
try:
|
| 247 |
model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
|
| 248 |
|
| 249 |
-
# Load with memory optimizations
|
| 250 |
pipe = StableDiffusionPipeline.from_pretrained(
|
| 251 |
model_id,
|
| 252 |
torch_dtype=torch.float32,
|
| 253 |
safety_checker=None,
|
| 254 |
requires_safety_checker=False,
|
| 255 |
cache_dir="./model_cache",
|
| 256 |
-
low_cpu_mem_usage=True,
|
| 257 |
-
use_safetensors=True
|
| 258 |
-
variant="fp32" # Use full precision for quality
|
| 259 |
)
|
| 260 |
|
| 261 |
-
# Use memory efficient scheduler
|
| 262 |
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 263 |
|
| 264 |
-
# Enable attention slicing
|
| 265 |
pipe.enable_attention_slicing()
|
| 266 |
|
| 267 |
# Enable sequential CPU offload if needed
|
|
@@ -273,11 +316,14 @@ def load_model(model_name="dreamshaper-8"):
|
|
| 273 |
model_cache[model_name] = pipe
|
| 274 |
current_pipe = pipe
|
| 275 |
current_model_name = model_name
|
|
|
|
| 276 |
|
| 277 |
print(f"β
Model loaded: {model_name}")
|
| 278 |
return pipe
|
| 279 |
|
| 280 |
except Exception as e:
|
|
|
|
|
|
|
| 281 |
print(f"β Model loading failed for {model_name}: {e}")
|
| 282 |
print(f"π Falling back to stable-diffusion-v1-5")
|
| 283 |
|
|
@@ -295,18 +341,30 @@ def load_model(model_name="dreamshaper-8"):
|
|
| 295 |
model_cache[model_name] = pipe
|
| 296 |
current_pipe = pipe
|
| 297 |
current_model_name = "sd-1.5"
|
|
|
|
| 298 |
|
| 299 |
print(f"β
Fallback model loaded")
|
| 300 |
return pipe
|
| 301 |
|
| 302 |
except Exception as fallback_error:
|
|
|
|
|
|
|
| 303 |
print(f"β Fallback model failed: {fallback_error}")
|
| 304 |
raise
|
| 305 |
|
| 306 |
-
#
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
# =============================================
|
| 312 |
# HF DATASET FUNCTIONS
|
|
@@ -387,7 +445,7 @@ def upload_image_to_hf_dataset(image, project_id, page_number, prompt, style="")
|
|
| 387 |
|
| 388 |
# PROMPT ENGINEERING
|
| 389 |
def enhance_prompt_simple(scene_visual, style="childrens_book"):
|
| 390 |
-
"""Simple prompt enhancement
|
| 391 |
|
| 392 |
style_templates = {
|
| 393 |
"childrens_book": "children's book illustration, watercolor style, soft colors, whimsical, magical, storybook art, professional illustration",
|
|
@@ -408,10 +466,13 @@ def enhance_prompt_simple(scene_visual, style="childrens_book"):
|
|
| 408 |
return enhanced_prompt, negative_prompt
|
| 409 |
|
| 410 |
# =============================================
|
| 411 |
-
#
|
| 412 |
# =============================================
|
| 413 |
def generate_image_simple(prompt, model_choice, style, scene_number, consistency_seed=None):
|
| 414 |
-
"""Generate image with
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
enhanced_prompt, negative_prompt = enhance_prompt_simple(prompt, style)
|
| 417 |
|
|
@@ -421,21 +482,19 @@ def generate_image_simple(prompt, model_choice, style, scene_number, consistency
|
|
| 421 |
scene_seed = random.randint(1000, 9999)
|
| 422 |
|
| 423 |
try:
|
| 424 |
-
pipe =
|
| 425 |
|
| 426 |
-
|
| 427 |
-
with torch.inference_mode(): # More memory efficient than no_grad
|
| 428 |
image = pipe(
|
| 429 |
prompt=enhanced_prompt,
|
| 430 |
negative_prompt=negative_prompt,
|
| 431 |
-
num_inference_steps=
|
| 432 |
guidance_scale=7.5,
|
| 433 |
-
width=
|
| 434 |
-
height=
|
| 435 |
generator=torch.Generator(device="cpu").manual_seed(scene_seed)
|
| 436 |
).images[0]
|
| 437 |
|
| 438 |
-
# Clean up after generation
|
| 439 |
if torch.cuda.is_available():
|
| 440 |
torch.cuda.empty_cache()
|
| 441 |
|
|
@@ -555,6 +614,9 @@ def update_job_status(job_id: str, status: JobStatus, progress: int, message: st
|
|
| 555 |
if result:
|
| 556 |
job_storage[job_id]["result"] = result
|
| 557 |
|
|
|
|
|
|
|
|
|
|
| 558 |
if request_data.get("callback_url"):
|
| 559 |
try:
|
| 560 |
callback_url = request_data["callback_url"]
|
|
@@ -597,7 +659,7 @@ def calculate_remaining_time(job_id, progress):
|
|
| 597 |
|
| 598 |
return "Unknown"
|
| 599 |
|
| 600 |
-
#
|
| 601 |
def generate_storybook_background(job_id: str):
|
| 602 |
"""Background task with memory optimization"""
|
| 603 |
try:
|
|
@@ -663,7 +725,7 @@ def generate_storybook_background(job_id: str):
|
|
| 663 |
}
|
| 664 |
generated_pages.append(page_data)
|
| 665 |
|
| 666 |
-
# Clean up
|
| 667 |
if torch.cuda.is_available():
|
| 668 |
torch.cuda.empty_cache()
|
| 669 |
gc.collect()
|
|
@@ -696,14 +758,87 @@ def generate_storybook_background(job_id: str):
|
|
| 696 |
except Exception as e:
|
| 697 |
error_msg = f"Generation failed: {str(e)}"
|
| 698 |
print(f"β {error_msg}")
|
|
|
|
| 699 |
update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
|
| 700 |
|
|
|
|
| 701 |
# FASTAPI ENDPOINTS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 702 |
@app.post("/api/generate-storybook")
|
| 703 |
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
|
|
|
|
| 704 |
try:
|
| 705 |
print(f"π₯ Received request for: {request.get('story_title', 'Unknown')}")
|
| 706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
if 'consistency_seed' not in request:
|
| 708 |
request['consistency_seed'] = random.randint(1000, 9999)
|
| 709 |
|
|
@@ -729,10 +864,13 @@ async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
|
|
| 729 |
}
|
| 730 |
|
| 731 |
except Exception as e:
|
|
|
|
|
|
|
| 732 |
raise HTTPException(status_code=500, detail=str(e))
|
| 733 |
|
| 734 |
@app.get("/api/job-status/{job_id}")
|
| 735 |
async def get_job_status(job_id: str):
|
|
|
|
| 736 |
job_data = job_storage.get(job_id)
|
| 737 |
if not job_data:
|
| 738 |
raise HTTPException(status_code=404, detail="Job not found")
|
|
@@ -745,17 +883,9 @@ async def get_job_status(job_id: str):
|
|
| 745 |
"result": job_data["result"]
|
| 746 |
}
|
| 747 |
|
| 748 |
-
@app.get("/api/health")
|
| 749 |
-
async def health():
|
| 750 |
-
return {
|
| 751 |
-
"status": "healthy",
|
| 752 |
-
"service": "storybook-generator",
|
| 753 |
-
"hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
|
| 754 |
-
"active_jobs": len(job_storage)
|
| 755 |
-
}
|
| 756 |
-
|
| 757 |
@app.get("/api/project-images/{project_id}")
|
| 758 |
async def get_project_images(project_id: str):
|
|
|
|
| 759 |
try:
|
| 760 |
if not HF_TOKEN:
|
| 761 |
return {"error": "HF_TOKEN not set"}
|
|
@@ -770,16 +900,47 @@ async def get_project_images(project_id: str):
|
|
| 770 |
except Exception as e:
|
| 771 |
return {"error": str(e)}
|
| 772 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
# GRADIO INTERFACE
|
|
|
|
| 774 |
def create_gradio_interface():
|
| 775 |
def generate_test(prompt, model_choice, style_choice):
|
| 776 |
if not prompt.strip():
|
| 777 |
return None, "β Please enter a prompt"
|
| 778 |
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
|
| 784 |
with gr.Blocks(title="Storybook Generator") as demo:
|
| 785 |
gr.Markdown("# π¨ Storybook Generator")
|
|
@@ -801,27 +962,31 @@ def create_gradio_interface():
|
|
| 801 |
|
| 802 |
demo = create_gradio_interface()
|
| 803 |
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
"message": "Storybook Generator API",
|
| 808 |
-
"hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
|
| 809 |
-
"endpoints": {
|
| 810 |
-
"generate": "POST /api/generate-storybook",
|
| 811 |
-
"status": "GET /api/job-status/{job_id}",
|
| 812 |
-
"health": "GET /api/health",
|
| 813 |
-
"project_images": "GET /api/project-images/{project_id}"
|
| 814 |
-
},
|
| 815 |
-
"ui": "/ui"
|
| 816 |
-
}
|
| 817 |
-
|
| 818 |
if __name__ == "__main__":
|
| 819 |
import uvicorn
|
| 820 |
|
| 821 |
if os.environ.get('SPACE_ID'):
|
| 822 |
print("π Running on Hugging Face Spaces")
|
| 823 |
print(f"π¦ HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 824 |
gr.mount_gradio_app(app, demo, path="/ui")
|
|
|
|
|
|
|
| 825 |
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
|
| 826 |
else:
|
| 827 |
print("π Running locally")
|
|
|
|
| 10 |
import time
|
| 11 |
import json
|
| 12 |
from typing import List, Optional, Dict
|
| 13 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
| 14 |
+
from fastapi.responses import JSONResponse
|
| 15 |
from pydantic import BaseModel
|
| 16 |
import gc
|
| 17 |
import psutil
|
|
|
|
| 24 |
from requests.adapters import HTTPAdapter
|
| 25 |
from urllib3.util.retry import Retry
|
| 26 |
from huggingface_hub import HfApi
|
| 27 |
+
import accelerate
|
| 28 |
+
import sys
|
| 29 |
+
import traceback
|
| 30 |
+
|
| 31 |
+
# =============================================
|
| 32 |
+
# INITIAL SETUP & DIAGNOSTICS
|
| 33 |
+
# =============================================
|
| 34 |
+
print("=" * 60)
|
| 35 |
+
print("π STARTING STORYBOOK GENERATOR API")
|
| 36 |
+
print("=" * 60)
|
| 37 |
+
print(f"Python version: {sys.version}")
|
| 38 |
+
print(f"PyTorch version: {torch.__version__}")
|
| 39 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 40 |
+
|
| 41 |
+
# Create diagnostic endpoint that works immediately
|
| 42 |
+
app = FastAPI(title="Storybook Generator API")
|
| 43 |
+
|
| 44 |
+
@app.get("/ping")
|
| 45 |
+
async def ping():
|
| 46 |
+
"""Simple ping endpoint that always works"""
|
| 47 |
+
return {
|
| 48 |
+
"status": "alive",
|
| 49 |
+
"timestamp": datetime.now().isoformat(),
|
| 50 |
+
"message": "Basic endpoint is working"
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
@app.get("/debug")
|
| 54 |
+
async def debug():
|
| 55 |
+
"""Debug endpoint showing system status"""
|
| 56 |
+
return {
|
| 57 |
+
"app_started": True,
|
| 58 |
+
"python_version": sys.version,
|
| 59 |
+
"torch_version": torch.__version__,
|
| 60 |
+
"cuda_available": torch.cuda.is_available(),
|
| 61 |
+
"routes": [{"path": route.path, "methods": list(route.methods)} for route in app.routes],
|
| 62 |
+
"hf_token_set": bool(os.environ.get("HF_TOKEN")),
|
| 63 |
+
"timestamp": datetime.now().isoformat()
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
# Add CORS middleware
|
| 67 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 68 |
+
app.add_middleware(
|
| 69 |
+
CORSMiddleware,
|
| 70 |
+
allow_origins=["*"],
|
| 71 |
+
allow_credentials=True,
|
| 72 |
+
allow_methods=["*"],
|
| 73 |
+
allow_headers=["*"],
|
| 74 |
+
)
|
| 75 |
|
| 76 |
# =============================================
|
| 77 |
# MEMORY OPTIMIZATION SETTINGS
|
|
|
|
| 105 |
os.makedirs(PERSISTENT_IMAGE_DIR, exist_ok=True)
|
| 106 |
print(f"π Created local image directory: {PERSISTENT_IMAGE_DIR}")
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
# Job Status Enum
|
| 109 |
class JobStatus(str, Enum):
|
| 110 |
PENDING = "pending"
|
|
|
|
| 175 |
current_model_name = None
|
| 176 |
current_pipe = None
|
| 177 |
model_lock = threading.Lock()
|
| 178 |
+
model_loading = False
|
| 179 |
+
model_load_error = None
|
| 180 |
|
| 181 |
# MEMORY MANAGEMENT FUNCTIONS
|
| 182 |
def get_memory_usage():
|
|
|
|
| 267 |
}
|
| 268 |
|
| 269 |
# =============================================
|
| 270 |
+
# MODEL LOADING WITH PROPER ERROR HANDLING
|
| 271 |
# =============================================
|
| 272 |
def load_model(model_name="dreamshaper-8"):
|
| 273 |
+
"""Thread-safe model loading with error handling"""
|
| 274 |
+
global model_cache, current_model_name, current_pipe, model_loading, model_load_error
|
| 275 |
|
| 276 |
with model_lock:
|
| 277 |
if model_name in model_cache:
|
|
|
|
| 279 |
current_model_name = model_name
|
| 280 |
return current_pipe
|
| 281 |
|
| 282 |
+
if model_loading:
|
| 283 |
+
print(f"β³ Model already loading, waiting...")
|
| 284 |
+
return None
|
| 285 |
+
|
| 286 |
+
model_loading = True
|
| 287 |
+
model_load_error = None
|
| 288 |
+
|
| 289 |
print(f"π Loading model: {model_name}")
|
| 290 |
try:
|
| 291 |
model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
|
| 292 |
|
| 293 |
+
# Load with memory optimizations
|
| 294 |
pipe = StableDiffusionPipeline.from_pretrained(
|
| 295 |
model_id,
|
| 296 |
torch_dtype=torch.float32,
|
| 297 |
safety_checker=None,
|
| 298 |
requires_safety_checker=False,
|
| 299 |
cache_dir="./model_cache",
|
| 300 |
+
low_cpu_mem_usage=True,
|
| 301 |
+
use_safetensors=True
|
|
|
|
| 302 |
)
|
| 303 |
|
| 304 |
+
# Use memory efficient scheduler
|
| 305 |
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 306 |
|
| 307 |
+
# Enable attention slicing
|
| 308 |
pipe.enable_attention_slicing()
|
| 309 |
|
| 310 |
# Enable sequential CPU offload if needed
|
|
|
|
| 316 |
model_cache[model_name] = pipe
|
| 317 |
current_pipe = pipe
|
| 318 |
current_model_name = model_name
|
| 319 |
+
model_loading = False
|
| 320 |
|
| 321 |
print(f"β
Model loaded: {model_name}")
|
| 322 |
return pipe
|
| 323 |
|
| 324 |
except Exception as e:
|
| 325 |
+
model_load_error = str(e)
|
| 326 |
+
model_loading = False
|
| 327 |
print(f"β Model loading failed for {model_name}: {e}")
|
| 328 |
print(f"π Falling back to stable-diffusion-v1-5")
|
| 329 |
|
|
|
|
| 341 |
model_cache[model_name] = pipe
|
| 342 |
current_pipe = pipe
|
| 343 |
current_model_name = "sd-1.5"
|
| 344 |
+
model_loading = False
|
| 345 |
|
| 346 |
print(f"β
Fallback model loaded")
|
| 347 |
return pipe
|
| 348 |
|
| 349 |
except Exception as fallback_error:
|
| 350 |
+
model_load_error = str(fallback_error)
|
| 351 |
+
model_loading = False
|
| 352 |
print(f"β Fallback model failed: {fallback_error}")
|
| 353 |
raise
|
| 354 |
|
| 355 |
+
# Try to load model in background thread to not block startup
|
| 356 |
+
def load_model_background():
|
| 357 |
+
try:
|
| 358 |
+
load_model("dreamshaper-8")
|
| 359 |
+
except Exception as e:
|
| 360 |
+
print(f"β Background model loading failed: {e}")
|
| 361 |
+
|
| 362 |
+
# Start model loading in background
|
| 363 |
+
import threading
|
| 364 |
+
model_thread = threading.Thread(target=load_model_background)
|
| 365 |
+
model_thread.daemon = True
|
| 366 |
+
model_thread.start()
|
| 367 |
+
print("β³ Model loading started in background...")
|
| 368 |
|
| 369 |
# =============================================
|
| 370 |
# HF DATASET FUNCTIONS
|
|
|
|
| 445 |
|
| 446 |
# PROMPT ENGINEERING
|
| 447 |
def enhance_prompt_simple(scene_visual, style="childrens_book"):
|
| 448 |
+
"""Simple prompt enhancement"""
|
| 449 |
|
| 450 |
style_templates = {
|
| 451 |
"childrens_book": "children's book illustration, watercolor style, soft colors, whimsical, magical, storybook art, professional illustration",
|
|
|
|
| 466 |
return enhanced_prompt, negative_prompt
|
| 467 |
|
| 468 |
# =============================================
|
| 469 |
+
# IMAGE GENERATION
|
| 470 |
# =============================================
|
| 471 |
def generate_image_simple(prompt, model_choice, style, scene_number, consistency_seed=None):
|
| 472 |
+
"""Generate image with error handling"""
|
| 473 |
+
|
| 474 |
+
if current_pipe is None:
|
| 475 |
+
raise Exception("Model not loaded yet. Please wait a few seconds and try again.")
|
| 476 |
|
| 477 |
enhanced_prompt, negative_prompt = enhance_prompt_simple(prompt, style)
|
| 478 |
|
|
|
|
| 482 |
scene_seed = random.randint(1000, 9999)
|
| 483 |
|
| 484 |
try:
|
| 485 |
+
pipe = current_pipe
|
| 486 |
|
| 487 |
+
with torch.inference_mode():
|
|
|
|
| 488 |
image = pipe(
|
| 489 |
prompt=enhanced_prompt,
|
| 490 |
negative_prompt=negative_prompt,
|
| 491 |
+
num_inference_steps=25, # Reduced for speed
|
| 492 |
guidance_scale=7.5,
|
| 493 |
+
width=512, # Reduced for memory
|
| 494 |
+
height=512, # Reduced for memory
|
| 495 |
generator=torch.Generator(device="cpu").manual_seed(scene_seed)
|
| 496 |
).images[0]
|
| 497 |
|
|
|
|
| 498 |
if torch.cuda.is_available():
|
| 499 |
torch.cuda.empty_cache()
|
| 500 |
|
|
|
|
| 614 |
if result:
|
| 615 |
job_storage[job_id]["result"] = result
|
| 616 |
|
| 617 |
+
job_data = job_storage[job_id]
|
| 618 |
+
request_data = job_data["request"]
|
| 619 |
+
|
| 620 |
if request_data.get("callback_url"):
|
| 621 |
try:
|
| 622 |
callback_url = request_data["callback_url"]
|
|
|
|
| 659 |
|
| 660 |
return "Unknown"
|
| 661 |
|
| 662 |
+
# BACKGROUND TASK
|
| 663 |
def generate_storybook_background(job_id: str):
|
| 664 |
"""Background task with memory optimization"""
|
| 665 |
try:
|
|
|
|
| 725 |
}
|
| 726 |
generated_pages.append(page_data)
|
| 727 |
|
| 728 |
+
# Clean up
|
| 729 |
if torch.cuda.is_available():
|
| 730 |
torch.cuda.empty_cache()
|
| 731 |
gc.collect()
|
|
|
|
| 758 |
except Exception as e:
|
| 759 |
error_msg = f"Generation failed: {str(e)}"
|
| 760 |
print(f"β {error_msg}")
|
| 761 |
+
traceback.print_exc()
|
| 762 |
update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
|
| 763 |
|
| 764 |
+
# =============================================
|
| 765 |
# FASTAPI ENDPOINTS
|
| 766 |
+
# =============================================
|
| 767 |
+
|
| 768 |
+
@app.get("/")
|
| 769 |
+
async def root():
|
| 770 |
+
"""Root endpoint showing API status"""
|
| 771 |
+
return {
|
| 772 |
+
"name": "Storybook Generator API",
|
| 773 |
+
"version": "1.0.0",
|
| 774 |
+
"status": "running",
|
| 775 |
+
"model_status": {
|
| 776 |
+
"loaded": current_model_name is not None,
|
| 777 |
+
"model_name": current_model_name,
|
| 778 |
+
"loading": model_loading,
|
| 779 |
+
"error": model_load_error
|
| 780 |
+
},
|
| 781 |
+
"hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
|
| 782 |
+
"endpoints": {
|
| 783 |
+
"ping": "GET /ping",
|
| 784 |
+
"debug": "GET /debug",
|
| 785 |
+
"health": "GET /api/health",
|
| 786 |
+
"generate": "POST /api/generate-storybook",
|
| 787 |
+
"status": "GET /api/job-status/{job_id}",
|
| 788 |
+
"project_images": "GET /api/project-images/{project_id}",
|
| 789 |
+
"memory": "GET /api/memory-status",
|
| 790 |
+
"clear_memory": "POST /api/clear-memory",
|
| 791 |
+
"local_images": "GET /api/local-images"
|
| 792 |
+
},
|
| 793 |
+
"ui": "/ui",
|
| 794 |
+
"test_commands": {
|
| 795 |
+
"ping": "curl -X GET https://yukee1992-Video_image_generator.hf.space/ping",
|
| 796 |
+
"health": "curl -X GET https://yukee1992-Video_image_generator.hf.space/api/health",
|
| 797 |
+
"generate": "curl -X POST https://yukee1992-Video_image_generator.hf.space/api/generate-storybook -H 'Content-Type: application/json' -d '{\"story_title\":\"test\",\"scenes\":[{\"visual\":\"a cat\",\"text\":\"test\"}]}'"
|
| 798 |
+
}
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
@app.get("/api/health")
|
| 802 |
+
async def health():
|
| 803 |
+
"""Health check endpoint"""
|
| 804 |
+
return {
|
| 805 |
+
"status": "healthy",
|
| 806 |
+
"service": "storybook-generator",
|
| 807 |
+
"model_loaded": current_model_name is not None,
|
| 808 |
+
"model_name": current_model_name,
|
| 809 |
+
"model_loading": model_loading,
|
| 810 |
+
"hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
|
| 811 |
+
"active_jobs": len(job_storage),
|
| 812 |
+
"timestamp": datetime.now().isoformat()
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
@app.post("/api/generate-storybook")
|
| 816 |
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
|
| 817 |
+
"""Generate a storybook from scenes"""
|
| 818 |
try:
|
| 819 |
print(f"π₯ Received request for: {request.get('story_title', 'Unknown')}")
|
| 820 |
|
| 821 |
+
# Check if model is loaded
|
| 822 |
+
if current_pipe is None:
|
| 823 |
+
if model_loading:
|
| 824 |
+
return JSONResponse(
|
| 825 |
+
status_code=503,
|
| 826 |
+
content={
|
| 827 |
+
"status": "loading",
|
| 828 |
+
"message": "Model is still loading. Please wait a few seconds and try again.",
|
| 829 |
+
"estimated_time": "10-20 seconds"
|
| 830 |
+
}
|
| 831 |
+
)
|
| 832 |
+
else:
|
| 833 |
+
return JSONResponse(
|
| 834 |
+
status_code=503,
|
| 835 |
+
content={
|
| 836 |
+
"status": "error",
|
| 837 |
+
"message": f"Model failed to load: {model_load_error}",
|
| 838 |
+
"error": model_load_error
|
| 839 |
+
}
|
| 840 |
+
)
|
| 841 |
+
|
| 842 |
if 'consistency_seed' not in request:
|
| 843 |
request['consistency_seed'] = random.randint(1000, 9999)
|
| 844 |
|
|
|
|
| 864 |
}
|
| 865 |
|
| 866 |
except Exception as e:
|
| 867 |
+
print(f"β Error in generate_storybook: {e}")
|
| 868 |
+
traceback.print_exc()
|
| 869 |
raise HTTPException(status_code=500, detail=str(e))
|
| 870 |
|
| 871 |
@app.get("/api/job-status/{job_id}")
|
| 872 |
async def get_job_status(job_id: str):
|
| 873 |
+
"""Get job status by ID"""
|
| 874 |
job_data = job_storage.get(job_id)
|
| 875 |
if not job_data:
|
| 876 |
raise HTTPException(status_code=404, detail="Job not found")
|
|
|
|
| 883 |
"result": job_data["result"]
|
| 884 |
}
|
| 885 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 886 |
@app.get("/api/project-images/{project_id}")
|
| 887 |
async def get_project_images(project_id: str):
|
| 888 |
+
"""Get all images for a project from HF Dataset"""
|
| 889 |
try:
|
| 890 |
if not HF_TOKEN:
|
| 891 |
return {"error": "HF_TOKEN not set"}
|
|
|
|
| 900 |
except Exception as e:
|
| 901 |
return {"error": str(e)}
|
| 902 |
|
| 903 |
+
@app.get("/api/memory-status")
|
| 904 |
+
async def memory_status():
|
| 905 |
+
"""Get memory usage status"""
|
| 906 |
+
return get_memory_usage()
|
| 907 |
+
|
| 908 |
+
@app.post("/api/clear-memory")
|
| 909 |
+
async def clear_memory_api(request: MemoryClearanceRequest):
|
| 910 |
+
"""Clear memory manually"""
|
| 911 |
+
return clear_memory(
|
| 912 |
+
clear_models=request.clear_models,
|
| 913 |
+
clear_jobs=request.clear_jobs,
|
| 914 |
+
clear_local_images=request.clear_local_images,
|
| 915 |
+
force_gc=request.force_gc
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
@app.get("/api/local-images")
|
| 919 |
+
async def get_local_images():
|
| 920 |
+
"""Get locally saved images"""
|
| 921 |
+
return get_local_storage_info()
|
| 922 |
+
|
| 923 |
+
# =============================================
|
| 924 |
# GRADIO INTERFACE
|
| 925 |
+
# =============================================
|
| 926 |
def create_gradio_interface():
|
| 927 |
def generate_test(prompt, model_choice, style_choice):
|
| 928 |
if not prompt.strip():
|
| 929 |
return None, "β Please enter a prompt"
|
| 930 |
|
| 931 |
+
try:
|
| 932 |
+
if current_pipe is None:
|
| 933 |
+
if model_loading:
|
| 934 |
+
return None, "β³ Model is still loading. Please wait a few seconds..."
|
| 935 |
+
else:
|
| 936 |
+
return None, f"β Model failed to load: {model_load_error}"
|
| 937 |
+
|
| 938 |
+
image = generate_image_simple(prompt, model_choice, style_choice, 1)
|
| 939 |
+
filepath, filename = save_image_to_local(image, prompt, style_choice)
|
| 940 |
+
|
| 941 |
+
return image, f"β
Generated! Local: {filename}"
|
| 942 |
+
except Exception as e:
|
| 943 |
+
return None, f"β Error: {str(e)}"
|
| 944 |
|
| 945 |
with gr.Blocks(title="Storybook Generator") as demo:
|
| 946 |
gr.Markdown("# π¨ Storybook Generator")
|
|
|
|
| 962 |
|
| 963 |
demo = create_gradio_interface()
|
| 964 |
|
| 965 |
+
# =============================================
|
| 966 |
+
# MAIN
|
| 967 |
+
# =============================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
if __name__ == "__main__":
|
| 969 |
import uvicorn
|
| 970 |
|
| 971 |
if os.environ.get('SPACE_ID'):
|
| 972 |
print("π Running on Hugging Face Spaces")
|
| 973 |
print(f"π¦ HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
|
| 974 |
+
print("π‘ API endpoints:")
|
| 975 |
+
print(" - GET /ping")
|
| 976 |
+
print(" - GET /debug")
|
| 977 |
+
print(" - GET /api/health")
|
| 978 |
+
print(" - POST /api/generate-storybook")
|
| 979 |
+
print(" - GET /api/job-status/{{job_id}}")
|
| 980 |
+
print(" - GET /api/project-images/{{project_id}}")
|
| 981 |
+
print(" - GET /api/memory-status")
|
| 982 |
+
print(" - POST /api/clear-memory")
|
| 983 |
+
print(" - GET /api/local-images")
|
| 984 |
+
print("π¨ UI: /ui")
|
| 985 |
+
|
| 986 |
+
# Mount Gradio
|
| 987 |
gr.mount_gradio_app(app, demo, path="/ui")
|
| 988 |
+
|
| 989 |
+
# Run the combined app
|
| 990 |
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
|
| 991 |
else:
|
| 992 |
print("π Running locally")
|