Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler | |
| from PIL import Image | |
| import io | |
| import requests | |
| import os | |
| from datetime import datetime | |
| import re | |
| import time | |
| import json | |
| from typing import List, Optional, Dict | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from pydantic import BaseModel | |
| import gc | |
| import psutil | |
| import threading | |
| import uuid | |
| import hashlib | |
| from enum import Enum | |
| import random | |
| import time | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util.retry import Retry | |
# External OCI API URL - YOUR BUCKET SAVING API
# Base URL of the external upload service; the code below appends
# "/api/upload" and "/api/health" to it.
OCI_API_BASE_URL = "https://yukee1992-oci-story-book.hf.space"

# Create local directories for test images
# The directory is created at import time so later saves can assume it exists.
PERSISTENT_IMAGE_DIR = "generated_test_images"
os.makedirs(PERSISTENT_IMAGE_DIR, exist_ok=True)
print(f"📁 Created local image directory: {PERSISTENT_IMAGE_DIR}")

# Initialize FastAPI app
app = FastAPI(title="Storybook Generator API")

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy - confirm this is intended for production.
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Job Status Enum
class JobStatus(str, Enum):
    """Lifecycle states of a storybook generation job.

    Mixes in ``str`` so each member compares equal to its string value.
    """

    PENDING = "pending"        # job created, not yet picked up
    PROCESSING = "processing"  # background generation is running
    COMPLETED = "completed"    # all pages were produced
    FAILED = "failed"          # generation aborted with an error
# Simple Story scene model
class StoryScene(BaseModel):
    """One page of a story: an image prompt plus the page text."""

    # Visual description; passed as the prompt to image generation.
    visual: str
    # Narrative text of the page; uploaded as the page's text file.
    text: str
class CharacterDescription(BaseModel):
    """Name and free-text description of a story character."""

    name: str
    description: str
class StorybookRequest(BaseModel):
    """Payload describing a complete storybook to generate."""

    story_title: str
    # Scenes are rendered in order, one page per scene.
    scenes: List[StoryScene]
    characters: List[CharacterDescription] = []
    # Key into MODEL_CHOICES.
    model_choice: str = "dreamshaper-8"
    # Key into the style templates used when building the prompt.
    style: str = "childrens_book"
    # When set, job status updates are POSTed to this URL.
    callback_url: Optional[str] = None
    # Base seed; the scene number is added to it for each page.
    consistency_seed: Optional[int] = None
class JobStatusResponse(BaseModel):
    """Response body returned when a job's status is queried."""

    job_id: str
    status: JobStatus
    progress: int
    message: str
    # Populated once a result has been stored for the job.
    result: Optional[dict] = None
    # Timestamps as produced by time.time().
    created_at: float
    updated_at: float
class MemoryClearanceRequest(BaseModel):
    """Flags selecting which resources a memory-clearance call releases."""

    clear_models: bool = True          # drop every cached pipeline
    clear_jobs: bool = False           # drop completed/failed job records
    clear_local_images: bool = False   # delete locally saved images
    force_gc: bool = True              # run gc.collect() and empty the CUDA cache
class MemoryStatusResponse(BaseModel):
    """Snapshot of process memory, GPU memory and service counters."""

    memory_used_mb: float
    memory_percent: float
    models_loaded: int
    active_jobs: int
    local_images_count: int
    # GPU figures are None when CUDA is not available.
    gpu_memory_allocated_mb: Optional[float] = None
    gpu_memory_cached_mb: Optional[float] = None
    status: str
# HIGH-QUALITY MODEL SELECTION - ANIME FOCUSED & WORKING
# Maps the short model name accepted by the API/UI to the model identifier
# handed to StableDiffusionPipeline.from_pretrained.
MODEL_CHOICES = {
    "dreamshaper-8": "lykon/dreamshaper-8",
    "realistic-vision": "SG161222/Realistic_Vision_V5.1",
    "counterfeit": "gsdf/Counterfeit-V2.5",
    "pastel-mix": "andite/pastel-mix",
    "meina-mix": "Meina/MeinaMix",
    "meina-pastel": "Meina/MeinaPastel",
    "abyss-orange": "warriorxza/AbyssOrangeMix",
    "openjourney": "prompthero/openjourney",
    "sd-1.5": "runwayml/stable-diffusion-v1-5",
}

# GLOBAL STORAGE
job_storage = {}            # job_id -> job record dict (in-memory only)
model_cache = {}            # requested model name -> loaded pipeline
current_model_name = None   # name recorded by the most recent load_model call
current_pipe = None         # pipeline returned by the most recent load_model call
model_lock = threading.Lock()  # guards model_cache and the two "current" globals
# MEMORY MANAGEMENT FUNCTIONS
def get_memory_usage():
    """Get current memory usage statistics.

    Returns:
        dict with keys ``memory_used_mb`` (process RSS), ``memory_percent``,
        ``gpu_memory_allocated_mb``, ``gpu_memory_cached_mb``,
        ``models_loaded``, ``active_jobs`` and ``local_images_count``.
        The two GPU values are ``None`` only when CUDA is unavailable; a
        genuine reading of 0.0 MB is reported as ``0.0``.
    """
    process = psutil.Process()
    memory_used_mb = process.memory_info().rss / (1024 * 1024)
    memory_percent = process.memory_percent()

    # GPU memory if available
    gpu_memory_allocated_mb = None
    gpu_memory_cached_mb = None
    if torch.cuda.is_available():
        gpu_memory_allocated_mb = torch.cuda.memory_allocated() / (1024 * 1024)
        gpu_memory_cached_mb = torch.cuda.memory_reserved() / (1024 * 1024)

    def _rounded_or_none(value):
        # Compare against None explicitly: a truthiness test would turn a
        # valid 0.0 measurement into "not available".
        return None if value is None else round(value, 2)

    return {
        "memory_used_mb": round(memory_used_mb, 2),
        "memory_percent": round(memory_percent, 2),
        "gpu_memory_allocated_mb": _rounded_or_none(gpu_memory_allocated_mb),
        "gpu_memory_cached_mb": _rounded_or_none(gpu_memory_cached_mb),
        "models_loaded": len(model_cache),
        "active_jobs": len(job_storage),
        "local_images_count": len(refresh_local_images()),
    }
def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False, force_gc=True):
    """Release cached models, finished jobs and/or local images.

    Args:
        clear_models: Move every cached pipeline to CPU and empty the cache.
        clear_jobs: Remove job records whose status is COMPLETED or FAILED.
        clear_local_images: Delete every image reported by
            get_local_storage_info().
        force_gc: Run the garbage collector and, when CUDA is available,
            empty the CUDA cache.

    Returns:
        dict with ``status``, the list of ``actions_performed`` and the
        ``memory_after_cleanup`` snapshot from get_memory_usage().
    """
    global current_pipe, current_model_name

    actions = []

    # Clear model cache
    if clear_models:
        with model_lock:
            cached_count = len(model_cache)
            for cached_name, cached_pipe in model_cache.items():
                try:
                    # Move to CPU first if it's on GPU
                    if hasattr(cached_pipe, 'to'):
                        cached_pipe.to('cpu')
                    # Drop the loop's reference; the cache entry itself is
                    # removed by model_cache.clear() below.
                    del cached_pipe
                    actions.append(f"Unloaded model: {cached_name}")
                except Exception as unload_error:
                    actions.append(f"Error unloading {cached_name}: {str(unload_error)}")
            model_cache.clear()
            current_pipe = None
            current_model_name = None
        actions.append(f"Cleared {cached_count} models from cache")

    # Clear completed jobs
    if clear_jobs:
        finished_states = (JobStatus.COMPLETED, JobStatus.FAILED)
        # Collect ids first so the dict is not mutated while iterating it.
        finished_ids = [
            stored_id
            for stored_id, stored_job in job_storage.items()
            if stored_job["status"] in finished_states
        ]
        for stored_id in finished_ids:
            del job_storage[stored_id]
            actions.append(f"Cleared job: {stored_id}")
        actions.append(f"Cleared {len(finished_ids)} completed/failed jobs")

    # Clear local images
    if clear_local_images:
        try:
            storage_info = get_local_storage_info()
            deleted_count = 0
            for image_info in storage_info.get("images", ()):
                was_deleted, _ = delete_local_image(image_info["path"])
                if was_deleted:
                    deleted_count += 1
            actions.append(f"Deleted {deleted_count} local images")
        except Exception as image_error:
            actions.append(f"Error clearing local images: {str(image_error)}")

    # Force garbage collection
    if force_gc:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            actions.append("GPU cache cleared")
        actions.append("Garbage collection forced")

    # Get memory status after cleanup
    return {
        "status": "success",
        "actions_performed": actions,
        "memory_after_cleanup": get_memory_usage(),
    }
def load_model(model_name="dreamshaper-8"):
    """Return the pipeline for ``model_name``, loading it on first use.

    The whole operation runs under ``model_lock``. A cached pipeline is
    returned immediately. Otherwise the model id is looked up in
    MODEL_CHOICES (unknown names resolve to "lykon/dreamshaper-8"), loaded
    on CPU in float32 with the Euler-ancestral scheduler, and cached.

    If that load fails, "runwayml/stable-diffusion-v1-5" is loaded instead;
    it is cached under the *requested* name while ``current_model_name`` is
    set to "sd-1.5".

    Raises:
        Exception: whatever the fallback load raised, if it also fails.
    """
    global model_cache, current_model_name, current_pipe

    with model_lock:
        # Cache hit: reuse the pipeline that is already in memory.
        if model_name in model_cache:
            current_pipe = model_cache[model_name]
            current_model_name = model_name
            return current_pipe

        print(f"🔄 Loading HIGH-QUALITY model: {model_name}")
        shared_options = {
            "torch_dtype": torch.float32,
            "safety_checker": None,
            "requires_safety_checker": False,
        }

        try:
            repo_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
            print(f"🔧 Attempting to load: {repo_id}")
            loaded = StableDiffusionPipeline.from_pretrained(
                repo_id,
                local_files_only=False,  # Allow downloading if not cached
                cache_dir="./model_cache",  # Specific cache directory
                **shared_options,
            )
            loaded.scheduler = EulerAncestralDiscreteScheduler.from_config(loaded.scheduler.config)
            loaded = loaded.to("cpu")

            model_cache[model_name] = loaded
            current_pipe = loaded
            current_model_name = model_name
            print(f"✅ HIGH-QUALITY Model loaded: {model_name}")
            return loaded
        except Exception as primary_error:
            print(f"❌ Model loading failed for {model_name}: {primary_error}")
            print(f"🔄 Falling back to stable-diffusion-v1-5")

            # Fallback to base model
            try:
                fallback = StableDiffusionPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    **shared_options,
                ).to("cpu")

                model_cache[model_name] = fallback
                current_pipe = fallback
                current_model_name = "sd-1.5"
                print(f"✅ Fallback model loaded: stable-diffusion-v1-5")
                return fallback
            except Exception as fallback_error:
                print(f"❌ Critical: Fallback model also failed: {fallback_error}")
                raise
# Initialize default model
# Runs at import time: the default pipeline is loaded (and cached) before
# any request is served, so importing this module blocks until it finishes.
print("🚀 Initializing Storybook Generator API...")
load_model("dreamshaper-8")
print("✅ Model loaded and ready!")
# SIMPLE PROMPT ENGINEERING - USE PURE PROMPTS ONLY
def enhance_prompt_simple(scene_visual, style="childrens_book"):
    """Prefix the visual prompt with a style phrase and pick a negative prompt.

    Args:
        scene_visual: The caller's prompt; appended unchanged after the
            style phrase.
        style: One of "childrens_book", "realistic", "fantasy", "anime".
            Any other value uses the "childrens_book" phrase.

    Returns:
        Tuple ``(enhanced_prompt, negative_prompt)``.
    """
    # Style templates
    style_templates = {
        "childrens_book": "children's book illustration, watercolor style, soft colors, whimsical, magical, storybook art, professional illustration",
        "realistic": "photorealistic, detailed, natural lighting, professional photography",
        "fantasy": "fantasy art, magical, ethereal, digital painting, concept art",
        "anime": "anime style, Japanese animation, vibrant colors, detailed artwork",
    }

    if style in style_templates:
        style_phrase = style_templates[style]
    else:
        style_phrase = style_templates["childrens_book"]

    # Basic negative prompt for quality
    quality_negatives = "blurry, low quality, bad anatomy, deformed characters, wrong proportions, mismatched features"

    # Use only the provided visual prompt with style
    return f"{style_phrase}, {scene_visual}", quality_negatives
def generate_image_simple(prompt, model_choice, style, scene_number, consistency_seed=None):
    """Generate one image from ``prompt`` with the chosen model and style.

    Args:
        prompt: Visual prompt; only the style phrase is prepended to it.
        model_choice: Model name passed to load_model().
        style: Style key passed to enhance_prompt_simple().
        scene_number: Added to ``consistency_seed`` to derive the scene seed.
        consistency_seed: Base seed. ``None`` means "pick a random seed";
            any integer, including 0, is honored.

    Returns:
        The first image produced by the pipeline.

    Raises:
        Exception: any error from model loading or inference is logged and
            re-raised unchanged.
    """
    # Enhance prompt with simple style addition
    enhanced_prompt, negative_prompt = enhance_prompt_simple(prompt, style)

    # Use seed if provided. Compare against None explicitly so that a seed
    # of 0 is not mistaken for "no seed".
    if consistency_seed is not None:
        scene_seed = consistency_seed + scene_number
    else:
        scene_seed = random.randint(1000, 9999)

    try:
        pipe = load_model(model_choice)
        image = pipe(
            prompt=enhanced_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=35,
            guidance_scale=7.5,
            width=768,
            height=1024,  # Portrait for better full-body
            generator=torch.Generator(device="cpu").manual_seed(scene_seed),
        ).images[0]

        print(f"✅ Generated image for scene {scene_number}")
        print(f"🌱 Seed used: {scene_seed}")
        print(f"📝 Pure prompt used: {prompt}")
        return image
    except Exception as e:
        print(f"❌ Generation failed: {str(e)}")
        raise
# LOCAL FILE MANAGEMENT FUNCTIONS
def save_image_to_local(image, prompt, style="test"):
    """Save ``image`` as a PNG under PERSISTENT_IMAGE_DIR/<style>/.

    The file name is built from the first 50 characters of ``prompt``
    (reduced to alphanumerics, spaces, '-' and '_') plus a timestamp.

    Args:
        image: Object exposing ``save(path)``.
        prompt: Text used to derive the file name.
        style: Sub-folder name. It is reduced to the same safe character set
            as the prompt so a caller-supplied value such as "../x" cannot
            escape PERSISTENT_IMAGE_DIR; an empty result falls back to "test".

    Returns:
        ``(filepath, filename)`` on success, ``(None, None)`` on any failure.
    """
    allowed_punctuation = (' ', '-', '_')
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_prompt = "".join(
            c for c in prompt[:50] if c.isalnum() or c in allowed_punctuation
        ).rstrip()
        filename = f"image_{safe_prompt}_{timestamp}.png"

        # Create style subfolder. The style may come straight from an API
        # request, so strip path separators and dots before using it as a
        # directory name.
        safe_style = "".join(
            c for c in str(style) if c.isalnum() or c in allowed_punctuation
        ).strip() or "test"
        style_dir = os.path.join(PERSISTENT_IMAGE_DIR, safe_style)
        os.makedirs(style_dir, exist_ok=True)
        filepath = os.path.join(style_dir, filename)

        # Save the image
        image.save(filepath)
        print(f"💾 Image saved locally: {filepath}")
        return filepath, filename
    except Exception as e:
        print(f"❌ Failed to save locally: {e}")
        return None, None
def delete_local_image(filepath):
    """Remove ``filepath`` from disk.

    Returns:
        ``(True, message)`` when the file was removed; ``(False, message)``
        when it does not exist or removal raised an error.
    """
    try:
        # Guard clause: nothing to do for a path that is not there.
        if not os.path.exists(filepath):
            return False, f"❌ File not found: {filepath}"

        os.remove(filepath)
        print(f"🗑️ Deleted local image: {filepath}")
        removed_name = os.path.basename(filepath)
        return True, f"✅ Deleted: {removed_name}"
    except Exception as delete_error:
        return False, f"❌ Error deleting: {str(delete_error)}"
def get_local_storage_info():
    """Summarize the images stored under PERSISTENT_IMAGE_DIR.

    Walks the directory tree and considers files ending in .png, .jpg or
    .jpeg (case-sensitive).

    Returns:
        dict with ``total_files``, ``total_size_mb`` and ``images`` (a list
        of ``path``/``filename``/``size_kb``/``created`` records, newest
        first), or ``{"error": message}`` if anything raised.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    try:
        records = []
        bytes_total = 0
        for folder, _subdirs, names in os.walk(PERSISTENT_IMAGE_DIR):
            for name in names:
                if not name.endswith(image_suffixes):
                    continue
                full_path = os.path.join(folder, name)
                # Skip entries that vanished between listing and inspection.
                if not os.path.exists(full_path):
                    continue
                size_bytes = os.path.getsize(full_path)
                bytes_total += size_bytes
                records.append({
                    'path': full_path,
                    'filename': name,
                    'size_kb': round(size_bytes / 1024, 1),
                    'created': os.path.getctime(full_path),
                })

        records.sort(key=lambda record: record['created'], reverse=True)
        return {
            "total_files": len(records),
            "total_size_mb": round(bytes_total / (1024 * 1024), 2),
            "images": records,
        }
    except Exception as scan_error:
        return {"error": str(scan_error)}
def refresh_local_images():
    """Return the paths of all images found under PERSISTENT_IMAGE_DIR.

    Only files ending in .png, .jpg or .jpeg that still exist are listed.
    On any error the problem is printed and an empty list is returned.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    try:
        found_paths = []
        for folder, _subdirs, names in os.walk(PERSISTENT_IMAGE_DIR):
            candidates = (
                os.path.join(folder, name)
                for name in names
                if name.endswith(image_suffixes)
            )
            found_paths.extend(path for path in candidates if os.path.exists(path))
        return found_paths
    except Exception as scan_error:
        print(f"Error refreshing local images: {scan_error}")
        return []
# OCI BUCKET FUNCTIONS
def save_to_oci_bucket(image, text_content, story_title, page_number, file_type="image"):
    """Upload one page image or page text through the OCI upload API.

    Args:
        image: Object exposing ``save(fp, format=...)``; used only when
            ``file_type == "image"``.
        text_content: String to upload; used for any other ``file_type``.
        story_title: Used to build the ``stories/<title>`` sub-folder.
        page_number: Integer used in the ``page_NNN`` file name.
        file_type: ``"image"`` for a PNG upload, anything else for text.

    Returns:
        The ``file_url`` reported by the API, or ``'Unknown URL'`` if the
        success response carries none.

    Raises:
        Exception: ``"OCI upload failed: ..."`` for any failure, chained to
            the underlying error.
    """
    try:
        if file_type == "image":
            # Convert image to bytes
            img_bytes = io.BytesIO()
            image.save(img_bytes, format='PNG')
            file_data = img_bytes.getvalue()
            filename = f"page_{page_number:03d}.png"
            mime_type = "image/png"
        else:  # text
            file_data = text_content.encode('utf-8')
            filename = f"page_{page_number:03d}.txt"
            mime_type = "text/plain"

        # Use your OCI API to save the file
        api_url = f"{OCI_API_BASE_URL}/api/upload"
        files = {'file': (filename, file_data, mime_type)}
        data = {
            'project_id': 'storybook-library',
            'subfolder': f'stories/{story_title}'
        }

        # Create session with retry strategy
        retry_strategy = Retry(
            total=3,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["POST"],
            backoff_factor=1
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)

        # The context manager closes the session's connection pool even when
        # the request raises; previously the session was never closed.
        with requests.Session() as session:
            session.mount("http://", adapter)
            session.mount("https://", adapter)

            # INCREASED TIMEOUT WITH RETRY LOGIC
            response = session.post(api_url, files=files, data=data, timeout=60)
            print(f"📨 OCI API Response: {response.status_code}")

            if response.status_code != 200:
                raise Exception(f"HTTP Error: {response.status_code}")

            result = response.json()
            # .get() so a payload without 'status' yields a readable API
            # error instead of a bare KeyError.
            if result.get('status') != 'success':
                raise Exception(f"OCI API Error: {result.get('message', 'Unknown error')}")
            return result.get('file_url', 'Unknown URL')
    except Exception as e:
        raise Exception(f"OCI upload failed: {str(e)}") from e
def test_oci_connection():
    """Probe the OCI API health endpoint.

    Returns:
        True when ``GET <OCI_API_BASE_URL>/api/health`` answers 200 with a
        JSON body; False for any other status or any exception.
    """
    try:
        health_url = f"{OCI_API_BASE_URL}/api/health"
        print(f"🔧 Testing connection to: {health_url}")

        reply = requests.get(health_url, timeout=10)
        print(f"🔧 Connection test response: {reply.status_code}")

        if reply.status_code != 200:
            print(f"🔧 OCI API not healthy: {reply.status_code}")
            return False

        health_payload = reply.json()
        print(f"🔧 OCI API Health: {health_payload}")
        return True
    except Exception as probe_error:
        print(f"🔧 Connection test failed: {probe_error}")
        return False
# JOB MANAGEMENT FUNCTIONS
def create_job(story_request: StorybookRequest) -> str:
    """Register a new PENDING job for ``story_request`` and return its id.

    The request is stored as a plain dict (``story_request.dict()``) inside
    ``job_storage`` under a freshly generated UUID4 string.
    """
    new_job_id = str(uuid.uuid4())

    job_record = {
        "status": JobStatus.PENDING,
        "progress": 0,
        "message": "Job created and queued",
        "request": story_request.dict(),
        "result": None,
        "created_at": time.time(),
        "updated_at": time.time(),
        "pages": [],
    }
    job_storage[new_job_id] = job_record

    scene_count = len(story_request.scenes)
    print(f"📝 Created job {new_job_id} for story: {story_request.story_title}")
    print(f"📄 Scenes to generate: {scene_count}")
    return new_job_id
def _scene_number_from_progress(progress, total_scenes):
    """Recover the 1-based scene number from a PROCESSING progress value.

    The generator reports ``progress = 5 + int(scene / total_scenes * 90)``;
    this inverts that formula with an integer ceiling and clamps the result
    to ``1..total_scenes``. ``total_scenes`` must be positive.
    """
    # ceil((progress - 5) * total_scenes / 90) without floating point.
    estimated = -(((5 - progress) * total_scenes) // 90)
    return max(1, min(estimated, total_scenes))


def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
    """Update a stored job and, if configured, notify its callback URL.

    Args:
        job_id: Key into ``job_storage``.
        status: New job status.
        progress: Progress percentage to store and report.
        message: Human-readable status message.
        result: Optional result dict; stored only when truthy.

    Returns:
        False if ``job_id`` is unknown, True otherwise. Callback delivery is
        best-effort: any failure is printed and does not affect the return
        value.
    """
    if job_id not in job_storage:
        return False

    job_storage[job_id].update({
        "status": status,
        "progress": progress,
        "message": message,
        "updated_at": time.time()
    })
    if result:
        job_storage[job_id]["result"] = result

    # Send webhook notification if callback URL exists
    job_data = job_storage[job_id]
    request_data = job_data["request"]
    if request_data.get("callback_url"):
        try:
            callback_url = request_data["callback_url"]
            scenes = request_data["scenes"]

            # Enhanced callback data with scene information
            callback_data = {
                "job_id": job_id,
                "status": status.value,
                "progress": progress,
                "message": message,
                "story_title": request_data["story_title"],
                "total_scenes": len(scenes),
                "timestamp": time.time(),
                "source": "huggingface-storybook-generator",
                "estimated_time_remaining": calculate_remaining_time(job_id, progress)
            }

            # Add current scene info for processing jobs
            if status == JobStatus.PROCESSING:
                total_scenes = len(scenes)
                if total_scenes > 0:
                    # The previous formula, (progress - 5) // (90 // total) + 1,
                    # reported the scene after the one being generated and
                    # divided by zero for more than 90 scenes.
                    current_scene = _scene_number_from_progress(progress, total_scenes)
                    callback_data["current_scene"] = current_scene
                    callback_data["total_scenes"] = total_scenes

                    # Add scene description if available
                    scene_data = scenes[current_scene - 1]
                    callback_data["scene_description"] = scene_data.get("visual", "")[:100] + "..."
                    callback_data["current_prompt"] = scene_data.get("visual", "")

            # Add result data for completed jobs
            if status == JobStatus.COMPLETED and result:
                callback_data["result"] = {
                    "total_pages": result.get("total_pages", 0),
                    "generation_time": result.get("generation_time", 0),
                    "oci_bucket_url": result.get("oci_bucket_url", ""),
                    "pages_generated": result.get("generated_pages", 0),
                    "consistency_seed": result.get("consistency_seed", None)
                }

            headers = {
                'Content-Type': 'application/json',
                'User-Agent': 'Storybook-Generator/1.0'
            }

            print(f"📢 Sending callback to: {callback_url}")
            print(f"📊 Callback data: {json.dumps(callback_data, indent=2)}")

            # NOTE(review): callback_url comes from the request body and is
            # fetched server-side; consider restricting allowed hosts.
            response = requests.post(
                callback_url,
                json=callback_data,
                headers=headers,
                timeout=30
            )
            print(f"📢 Callback sent: Status {response.status_code}")
        except Exception as e:
            # Deliberately best-effort: a broken webhook must not fail the job.
            print(f"⚠️ Callback failed: {str(e)}")

    return True
def calculate_remaining_time(job_id, progress):
    """Estimate the remaining time for a job as an ``"Xm Ys"`` string.

    The estimate extrapolates linearly from the time elapsed since the job
    was created and the current ``progress`` percentage.

    Returns:
        "Calculating..." when progress is 0, "Unknown" when the job is not
        stored or progress is not positive, otherwise the formatted estimate.
    """
    if progress == 0:
        return "Calculating..."

    stored_job = job_storage.get(job_id)
    if not stored_job:
        return "Unknown"

    elapsed_seconds = time.time() - stored_job["created_at"]
    if not progress > 0:
        return "Unknown"

    projected_total = (elapsed_seconds / progress) * 100
    seconds_left = projected_total - elapsed_seconds
    whole_minutes, leftover_seconds = divmod(seconds_left, 60)
    return f"{int(whole_minutes)}m {int(leftover_seconds)}s"
# SIMPLE BACKGROUND TASK - USES PURE PROMPTS ONLY
def generate_storybook_background(job_id: str):
    """Generate every page of the storybook stored under ``job_id``.

    For each scene, in order: report progress, generate the image, save a
    local copy, then upload the image and the text to the OCI bucket. If an
    upload fails the page is still recorded, pointing at the local copy and
    carrying an ``upload_error`` entry. If image generation (or anything
    else outside the upload) fails, the job is marked FAILED and processing
    stops. On completion the job is marked COMPLETED with a result summary.

    Nothing is returned; all outcomes are reported via update_job_status().
    """
    try:
        # Test OCI connection first
        print("🔧 Testing OCI API connection...")
        if not test_oci_connection():
            print("⚠️ OCI API connection test failed - will use local fallback")

        story_request = StorybookRequest(**job_storage[job_id]["request"])
        title = story_request.story_title
        seed = story_request.consistency_seed
        total_scenes = len(story_request.scenes)

        print(f"🎬 Starting storybook generation for job {job_id}")
        print(f"📖 Story: {title}")
        print(f"📄 Scenes: {total_scenes}")
        print(f"🎨 Style: {story_request.style}")
        print(f"🌱 Consistency seed: {seed}")

        update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting storybook generation with pure prompts...")

        generated_pages = []
        start_time = time.time()

        for page_number, scene in enumerate(story_request.scenes, start=1):
            # FIXED: Better progress calculation
            progress = 5 + int((page_number / total_scenes) * 90)
            update_job_status(
                job_id,
                JobStatus.PROCESSING,
                progress,
                f"Generating page {page_number}/{total_scenes}: {scene.visual[:50]}...",
            )

            try:
                print(f"🖼️ Generating page {page_number}")
                print(f"📝 Pure prompt: {scene.visual}")

                # Generate image using pure prompt only
                image = generate_image_simple(
                    scene.visual,
                    story_request.model_choice,
                    story_request.style,
                    page_number,
                    seed,
                )

                # Save locally as backup
                local_filepath, local_filename = save_image_to_local(image, scene.visual, story_request.style)
                print(f"💾 Image saved locally as backup: {local_filename}")

                upload_failure = None
                try:
                    # Save IMAGE to OCI bucket (no text for image)
                    image_url = save_to_oci_bucket(image, "", title, page_number, "image")
                    # Save TEXT to OCI bucket (no image for text)
                    text_url = save_to_oci_bucket(None, scene.text, title, page_number, "text")
                except Exception as upload_error:
                    # If OCI upload fails, use local file as fallback
                    upload_failure = str(upload_error)
                    print(f"⚠️ OCI upload failed for page {page_number}, using local backup: {upload_failure}")
                    image_url = f"local://{local_filepath}"
                    text_url = f"local://text_content_{page_number}"

                # Store page data
                page_data = {
                    "page_number": page_number,
                    "image_url": image_url,
                    "text_url": text_url,
                    "text_content": scene.text,
                    "visual_description": scene.visual,
                    "prompt_used": scene.visual,  # Store the pure prompt
                    "local_backup_path": local_filepath,
                }
                if upload_failure is None:
                    generated_pages.append(page_data)
                    print(f"✅ Page {page_number} completed")
                else:
                    # Continue with next page instead of failing completely
                    page_data["upload_error"] = upload_failure
                    generated_pages.append(page_data)
            except Exception as page_error:
                failure_message = f"Failed to generate page {page_number}: {str(page_error)}"
                print(f"❌ {failure_message}")
                update_job_status(job_id, JobStatus.FAILED, 0, failure_message)
                return

        # Complete the job
        generation_time = time.time() - start_time

        # Count successful OCI uploads vs local fallbacks
        oci_success_count = sum(1 for page in generated_pages if not page.get("upload_error"))
        local_fallback_count = sum(1 for page in generated_pages if page.get("upload_error"))

        page_numbers = range(1, total_scenes + 1)
        result = {
            "story_title": title,
            "total_pages": total_scenes,
            "generated_pages": len(generated_pages),
            "generation_time": round(generation_time, 2),
            "folder_path": f"stories/{title}",
            "oci_bucket_url": f"https://oci.com/stories/{title}",
            "consistency_seed": seed,
            "pages": generated_pages,
            "file_structure": {
                "images": [f"page_{number:03d}.png" for number in page_numbers],
                "texts": [f"page_{number:03d}.txt" for number in page_numbers],
            },
            "upload_summary": {
                "oci_successful": oci_success_count,
                "local_fallback": local_fallback_count,
                "total_attempted": total_scenes,
            },
        }

        status_message = f"🎉 Storybook completed! {len(generated_pages)} pages created in {generation_time:.2f}s using pure prompts."
        if local_fallback_count > 0:
            status_message += f" {local_fallback_count} pages saved locally due to OCI upload issues."

        update_job_status(job_id, JobStatus.COMPLETED, 100, status_message, result)

        print(f"🎉 Storybook generation finished for job {job_id}")
        print(f"📁 OCI Uploads: {oci_success_count} successful, {local_fallback_count} local fallbacks")
        print(f"📝 All prompts used exactly as provided from Telegram")
    except Exception as job_error:
        failure_message = f"Story generation failed: {str(job_error)}"
        print(f"❌ {failure_message}")
        update_job_status(job_id, JobStatus.FAILED, 0, failure_message)
# FASTAPI ENDPOINTS (for n8n)
async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
    """Main endpoint for n8n integration - generates complete storybook using pure prompts.

    Validates the payload, registers a job, schedules the background
    generation task and returns immediately with the job id.

    Args:
        request: Raw request body; must be convertible to StorybookRequest.
            A missing or falsy ``consistency_seed`` is replaced in place by
            a random one.
        background_tasks: FastAPI task queue used to run the generation.

    Returns:
        dict describing the started job (``job_id``, ``total_scenes``,
        ``consistency_seed``, ``estimated_time_seconds``, ...).

    Raises:
        HTTPException: 400 when ``story_title`` or ``scenes`` is empty;
            500 for any other failure.
    """
    try:
        print(f"📥 Received n8n request for story: {request.get('story_title', 'Unknown')}")

        # Add consistency seed if not provided
        if 'consistency_seed' not in request or not request['consistency_seed']:
            request['consistency_seed'] = random.randint(1000, 9999)
            print(f"🌱 Generated consistency seed: {request['consistency_seed']}")

        # Convert to Pydantic model
        story_request = StorybookRequest(**request)

        # Validate required fields
        if not story_request.story_title or not story_request.scenes:
            raise HTTPException(status_code=400, detail="story_title and scenes are required")

        # Create job immediately
        job_id = create_job(story_request)

        # Start background processing
        background_tasks.add_task(generate_storybook_background, job_id)

        # Immediate response for n8n
        response_data = {
            "status": "success",
            "message": "Storybook generation with pure prompts started successfully",
            "job_id": job_id,
            "story_title": story_request.story_title,
            "total_scenes": len(story_request.scenes),
            "consistency_seed": story_request.consistency_seed,
            "callback_url": story_request.callback_url,
            "estimated_time_seconds": len(story_request.scenes) * 35,
            "timestamp": datetime.now().isoformat()
        }
        print(f"✅ Job {job_id} started with pure prompts for: {story_request.story_title}")
        return response_data
    except HTTPException:
        # Let deliberate HTTP errors (the 400 above) through unchanged; the
        # generic handler below would otherwise rewrite them as 500s.
        raise
    except Exception as e:
        error_msg = f"API Error: {str(e)}"
        print(f"❌ {error_msg}")
        raise HTTPException(status_code=500, detail=error_msg) from e
async def get_job_status_endpoint(job_id: str):
    """Return the stored status of ``job_id``.

    Raises:
        HTTPException: 404 when no job record is stored under ``job_id``.
    """
    stored_job = job_storage.get(job_id)
    if stored_job:
        copied_fields = ("status", "progress", "message", "result", "created_at", "updated_at")
        field_values = {field: stored_job[field] for field in copied_fields}
        return JobStatusResponse(job_id=job_id, **field_values)

    raise HTTPException(status_code=404, detail="Job not found")
async def api_health():
    """Health check endpoint for n8n.

    Reports a static "healthy" status together with the number of stored
    jobs, the names of cached models and the configured OCI API base URL.
    """
    loaded_model_names = list(model_cache.keys())
    health_report = {
        "status": "healthy",
        "service": "storybook-generator",
        "timestamp": datetime.now().isoformat(),
        "active_jobs": len(job_storage),
        "models_loaded": loaded_model_names,
        # Carries the configured URL; no connectivity check is performed here.
        "oci_api_connected": OCI_API_BASE_URL,
    }
    return health_report
# NEW MEMORY MANAGEMENT ENDPOINTS
async def get_memory_status():
    """Return the current memory snapshot as a MemoryStatusResponse.

    The figures come from get_memory_usage(); ``status`` is always
    "healthy".
    """
    usage = get_memory_usage()
    reported_fields = (
        "memory_used_mb",
        "memory_percent",
        "models_loaded",
        "active_jobs",
        "local_images_count",
        "gpu_memory_allocated_mb",
        "gpu_memory_cached_mb",
    )
    snapshot = {field: usage[field] for field in reported_fields}
    return MemoryStatusResponse(status="healthy", **snapshot)
async def clear_memory_endpoint(request: MemoryClearanceRequest):
    """Run clear_memory() with the flags from ``request``.

    Raises:
        HTTPException: 500 when clear_memory() raises.
    """
    clearance_options = {
        "clear_models": request.clear_models,
        "clear_jobs": request.clear_jobs,
        "clear_local_images": request.clear_local_images,
        "force_gc": request.force_gc,
    }
    try:
        cleanup_details = clear_memory(**clearance_options)
    except Exception as cleanup_error:
        raise HTTPException(status_code=500, detail=f"Memory clearance failed: {str(cleanup_error)}")

    return {
        "status": "success",
        "message": "Memory clearance completed",
        "details": cleanup_details,
    }
async def auto_cleanup():
    """Automatic cleanup - clears completed jobs and forces GC.

    Cached models and local images are left untouched.

    Raises:
        HTTPException: 500 when clear_memory() raises.
    """
    try:
        cleanup_details = clear_memory(
            clear_models=False,        # Don't clear models by default
            clear_jobs=True,           # Clear completed jobs
            clear_local_images=False,  # Don't clear images by default
            force_gc=True,             # Force garbage collection
        )
    except Exception as cleanup_error:
        raise HTTPException(status_code=500, detail=f"Auto cleanup failed: {str(cleanup_error)}")

    return {
        "status": "success",
        "message": "Automatic cleanup completed",
        "details": cleanup_details,
    }
async def get_local_images():
    """API endpoint to get locally saved test images.

    Returns the dict produced by get_local_storage_info() unchanged.
    """
    return get_local_storage_info()
async def delete_local_image_api(filename: str):
    """API endpoint to delete a local image.

    Args:
        filename: Path of the image relative to PERSISTENT_IMAGE_DIR
            (may include a style sub-folder).

    Returns:
        dict with ``status`` ("success" or "error") and a ``message``.
        Paths that resolve outside PERSISTENT_IMAGE_DIR are rejected with
        an error instead of being deleted.
    """
    try:
        base_dir = os.path.realpath(PERSISTENT_IMAGE_DIR)
        filepath = os.path.join(PERSISTENT_IMAGE_DIR, filename)
        resolved = os.path.realpath(filepath)

        # filename is caller-controlled: refuse "../" tricks, absolute paths
        # and symlinks that would point the delete outside the image folder.
        inside_base = os.path.commonpath([base_dir, resolved]) == base_dir
        if not inside_base or resolved == base_dir:
            return {"status": "error", "message": f"❌ Invalid filename: {filename}"}

        success, message = delete_local_image(filepath)
        return {"status": "success" if success else "error", "message": message}
    except Exception as e:
        return {"status": "error", "message": str(e)}
| # SIMPLE GRADIO INTERFACE | |
def create_gradio_interface():
    """Build the Gradio Blocks UI used for manual image-generation testing.

    The interface lets a user generate a single image from a prompt, browse
    and delete locally saved images, and trigger memory clean-up actions.
    All heavy lifting is delegated to module-level helpers
    (``generate_image_simple``, ``save_image_to_local``, ``delete_local_image``,
    ``refresh_local_images``, ``get_local_storage_info``, ``get_memory_usage``
    and ``clear_memory``).

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched or mounted).
    """

    def generate_test_image_simple(prompt, model_choice, style_choice):
        """Generate one image from ``prompt`` and save it locally.

        Returns a ``(image, status_message, filepath)`` tuple; on failure the
        image and filepath are ``None`` and the message describes the error.
        """
        try:
            # Guard against None as well as blank input so the user gets the
            # "enter a prompt" hint instead of an AttributeError message.
            if not prompt or not prompt.strip():
                return None, "❌ Please enter a prompt", None

            print(f"🎨 Generating test image with pure prompt: {prompt}")

            # Generate the image using the pure prompt
            image = generate_image_simple(prompt, model_choice, style_choice, 1)

            # Save to local storage
            filepath, filename = save_image_to_local(image, prompt, style_choice)

            status_msg = (
                f"✅ Success! Generated: {prompt}\n"
                f"📁 **Local file:** {filename if filename else 'Not saved'}"
            )
            return image, status_msg, filepath
        except Exception as e:
            # UI boundary: report the failure in the status box instead of crashing.
            error_msg = f"❌ Generation failed: {str(e)}"
            print(error_msg)
            return None, error_msg, None

    # NOTE(review): the widget nesting below was reconstructed from the
    # with-blocks and section comments — confirm against the rendered layout.
    with gr.Blocks(title="Simple Image Generator", theme="soft") as demo:
        gr.Markdown("# 🎨 Simple Image Generator")
        gr.Markdown("Generate images using **pure prompts only** - no automatic enhancements")

        # Storage info display
        storage_info = gr.Textbox(
            label="📊 Local Storage Information",
            interactive=False,
            lines=2,
        )

        # Memory status display
        memory_status = gr.Textbox(
            label="🧠 Memory Status",
            interactive=False,
            lines=3,
        )

        def update_storage_info():
            """Return a one-line summary of local image storage usage."""
            info = get_local_storage_info()
            if "error" not in info:
                return f"📁 Local Storage: {info['total_files']} images, {info['total_size_mb']} MB used"
            return "📁 Local Storage: Unable to calculate"

        def update_memory_status():
            """Return a multi-line summary of memory, model and job usage."""
            memory_info = get_memory_usage()
            lines = [
                f"🧠 Memory Usage: {memory_info['memory_used_mb']} MB ({memory_info['memory_percent']}%)",
                f"📦 Models Loaded: {memory_info['models_loaded']}",
                f"⚡ Active Jobs: {memory_info['active_jobs']}",
            ]
            # The GPU line is optional: tolerate a missing key as well as a falsy value.
            gpu_allocated = memory_info.get('gpu_memory_allocated_mb')
            if gpu_allocated:
                lines.append(f"🎮 GPU Memory: {gpu_allocated} MB allocated")
            return "\n".join(lines)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 🎯 Quality Settings")

                model_dropdown = gr.Dropdown(
                    label="AI Model",
                    choices=list(MODEL_CHOICES.keys()),
                    value="dreamshaper-8",
                )

                style_dropdown = gr.Dropdown(
                    label="Art Style",
                    choices=["childrens_book", "realistic", "fantasy", "anime"],
                    value="anime",
                )

                prompt_input = gr.Textbox(
                    label="Pure Prompt",
                    placeholder="Enter your exact prompt...",
                    lines=3,
                )

                generate_btn = gr.Button("✨ Generate Image", variant="primary")

                # Current image management
                current_file_path = gr.State()
                delete_btn = gr.Button("🗑️ Delete This Image", variant="stop")
                delete_status = gr.Textbox(label="Delete Status", interactive=False, lines=2)

                # Memory management section
                gr.Markdown("### 🧠 Memory Management")
                with gr.Row():
                    auto_cleanup_btn = gr.Button("🔄 Auto Cleanup", size="sm")
                    clear_models_btn = gr.Button("🗑️ Clear Models", variant="stop", size="sm")
                memory_clear_status = gr.Textbox(label="Memory Clear Status", interactive=False, lines=2)

                gr.Markdown("### 📚 API Usage for n8n")
                # Built line by line so the Markdown has no stray leading
                # whitespace (indented Markdown would render as a code block).
                gr.Markdown("\n".join([
                    "",
                    "**For complete storybooks (OCI bucket):**",
                    "- Endpoint: `POST /api/generate-storybook`",
                    "- Input: `story_title`, `scenes[]`, `characters[]`",
                    "- Output: Uses pure prompts only from your script",
                    "**Memory Management APIs:**",
                    "- `GET /api/memory-status` - Check memory usage",
                    "- `POST /api/clear-memory` - Clear memory",
                    "- `POST /api/auto-cleanup` - Auto cleanup jobs",
                    "",
                ]))

            with gr.Column(scale=2):
                image_output = gr.Image(label="Generated Image", height=500, show_download_button=True)
                status_output = gr.Textbox(label="Status", interactive=False, lines=4)

        # Local file management section
        with gr.Accordion("📁 Manage Local Test Images", open=True):
            gr.Markdown("### Locally Saved Images")

            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh List")
                clear_all_btn = gr.Button("🗑️ Clear All Images", variant="stop")

            file_gallery = gr.Gallery(
                label="Local Images",
                show_label=True,
                elem_id="gallery",
                columns=4,
                height="auto",
            )

            clear_status = gr.Textbox(label="Clear Status", interactive=False)

        def delete_current_image(filepath):
            """Delete the currently displayed image.

            Returns values for ``(delete_status, image_output, status_output,
            file_gallery, current_file_path)``. The stored path is cleared
            after a successful delete so a second click does not retry a file
            that no longer exists; it is kept on failure so the user can retry.
            """
            if not filepath:
                return "❌ No image to delete", None, None, refresh_local_images(), None

            success, message = delete_local_image(filepath)
            updated_files = refresh_local_images()

            if success:
                return f"✅ {message}", None, "Image deleted successfully!", updated_files, None
            return f"❌ {message}", None, "Delete failed", updated_files, filepath

        def clear_all_images():
            """Delete every locally stored image and report how many were removed."""
            try:
                info = get_local_storage_info()
                # Surface a storage-listing failure instead of claiming
                # "Deleted 0 images" (same "error" key update_storage_info checks).
                if "error" in info:
                    return f"❌ Error: {info['error']}", refresh_local_images()

                deleted_count = 0
                for image_info in (info["images"] if "images" in info else []):
                    success, _ = delete_local_image(image_info["path"])
                    if success:
                        deleted_count += 1

                return f"✅ Deleted {deleted_count} images", refresh_local_images()
            except Exception as e:
                return f"❌ Error: {str(e)}", refresh_local_images()

        def perform_auto_cleanup():
            """Clear finished jobs and force garbage collection; keep models and images."""
            try:
                result = clear_memory(
                    clear_models=False,
                    clear_jobs=True,
                    clear_local_images=False,
                    force_gc=True,
                )
                return f"✅ Auto cleanup completed: {len(result['actions_performed'])} actions"
            except Exception as e:
                return f"❌ Auto cleanup failed: {str(e)}"

        def clear_models():
            """Unload all loaded models and force garbage collection."""
            try:
                result = clear_memory(
                    clear_models=True,
                    clear_jobs=False,
                    clear_local_images=False,
                    force_gc=True,
                )
                return f"✅ Models cleared: {len(result['actions_performed'])} actions"
            except Exception as e:
                return f"❌ Model clearance failed: {str(e)}"

        def _then_refresh_status(event, include_storage=True):
            """Chain the status-panel refreshes after ``event`` and return the new event."""
            if include_storage:
                event = event.then(fn=update_storage_info, outputs=storage_info)
            return event.then(fn=update_memory_status, outputs=memory_status)

        # Connect buttons to functions
        _then_refresh_status(
            generate_btn.click(
                fn=generate_test_image_simple,
                inputs=[prompt_input, model_dropdown, style_dropdown],
                outputs=[image_output, status_output, current_file_path],
            ).then(
                fn=refresh_local_images,
                outputs=file_gallery,
            )
        )

        _then_refresh_status(
            delete_btn.click(
                fn=delete_current_image,
                inputs=current_file_path,
                outputs=[delete_status, image_output, status_output, file_gallery, current_file_path],
            )
        )

        _then_refresh_status(
            refresh_btn.click(
                fn=refresh_local_images,
                outputs=file_gallery,
            )
        )

        _then_refresh_status(
            clear_all_btn.click(
                fn=clear_all_images,
                outputs=[clear_status, file_gallery],
            )
        )

        # Memory management buttons only affect the memory panel
        _then_refresh_status(
            auto_cleanup_btn.click(
                fn=perform_auto_cleanup,
                outputs=memory_clear_status,
            ),
            include_storage=False,
        )

        _then_refresh_status(
            clear_models_btn.click(
                fn=clear_models,
                outputs=memory_clear_status,
            ),
            include_storage=False,
        )

        # Initialize on load
        demo.load(fn=refresh_local_images, outputs=file_gallery)
        demo.load(fn=update_storage_info, outputs=storage_info)
        demo.load(fn=update_memory_status, outputs=memory_status)

    return demo
# Build the Gradio UI once at import time so the entry-point code can mount or launch it
demo = create_gradio_interface()
| # Simple root endpoint | |
async def root():
    """Return a static description of the service.

    The payload lists the available API endpoints, the enabled feature
    flags and where the web interface lives.

    NOTE(review): no route decorator is visible in this view — confirm how
    this handler is registered on the app.
    """
    endpoints = {
        "health_check": "GET /api/health",
        "generate_storybook": "POST /api/generate-storybook",
        "check_job_status": "GET /api/job-status/{job_id}",
        "local_images": "GET /api/local-images",
        "memory_status": "GET /api/memory-status",
        "clear_memory": "POST /api/clear-memory",
        "auto_cleanup": "POST /api/auto-cleanup",
    }
    enabled_features = {
        "pure_prompts": "✅ Enabled - No automatic enhancements",
        "n8n_integration": "✅ Enabled",
        "memory_management": "✅ Enabled",
    }

    info = {"message": "Simple Storybook Generator API is running!"}
    info["api_endpoints"] = endpoints
    info["features"] = enabled_features
    info["web_interface"] = "GET /ui"
    return info
| # Add a simple test endpoint | |
async def test_endpoint():
    """Return a small liveness payload stamped with the current local time.

    The ``timestamp`` field is ``datetime.now()`` in ISO-8601 form (naive,
    i.e. without timezone information).

    NOTE(review): no route decorator is visible in this view — confirm how
    this handler is registered on the app.
    """
    payload = dict(
        status="success",
        message="API with pure prompts is working correctly",
        pure_prompts="✅ Enabled - Using exact prompts from Telegram",
        memory_management="✅ Enabled - Memory clearance available",
    )
    payload["timestamp"] = datetime.now().isoformat()
    return payload
| # For Hugging Face Spaces deployment | |
def get_app():
    """Return the module-level FastAPI ``app`` instance.

    Provided as a callable accessor for deployment tooling. Nothing is
    mounted or started here; the object is handed back exactly as it
    currently exists at module level.
    """
    return app
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly.
    # `os`, `threading` and `time` are already imported at the top of the file.
    import uvicorn

    # A set SPACE_ID environment variable is taken as the signal that we are
    # running on Hugging Face Spaces.
    HF_SPACE = os.environ.get('SPACE_ID') is not None

    if HF_SPACE:
        print("🚀 Running on Hugging Face Spaces - Integrated Mode")
        print("📚 API endpoints available at: /api/*")
        print("🎨 Web interface available at: /ui")
        print("📝 PURE PROMPTS enabled - no automatic enhancements")
        print("🧠 MEMORY MANAGEMENT enabled - automatic cleanup available")

        # Mount Gradio under /ui on the existing FastAPI app (app is not reassigned)
        gr.mount_gradio_app(app, demo, path="/ui")

        # Run the combined app on a single port
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=7860,
            log_level="info",
        )
    else:
        # Local development - run separate servers
        print("🚀 Running locally - Separate API and UI servers")
        print("📚 API endpoints: http://localhost:8000/api/*")
        # demo.launch() serves the UI at the server root; the /ui path only
        # exists in the mounted (Spaces) mode above.
        print("🎨 Web interface: http://localhost:7860")
        print("📝 PURE PROMPTS enabled - no automatic enhancements")
        print("🧠 MEMORY MANAGEMENT enabled - automatic cleanup available")

        def run_fastapi():
            """Run FastAPI on port 8000 for API calls"""
            uvicorn.run(
                app,
                host="0.0.0.0",
                port=8000,
                log_level="info",
                access_log=False,
            )

        def run_gradio():
            """Run Gradio on port 7860 for web interface"""
            demo.launch(server_name="0.0.0.0", server_port=7860, share=False)

        # Run both servers in daemon threads so they die with the main thread
        fastapi_thread = threading.Thread(target=run_fastapi, daemon=True)
        gradio_thread = threading.Thread(target=run_gradio, daemon=True)

        fastapi_thread.start()
        gradio_thread.start()

        try:
            # Keep the main thread alive only while at least one server is
            # still running; if both have stopped (e.g. ports already in use)
            # exit instead of idling forever.
            while fastapi_thread.is_alive() or gradio_thread.is_alive():
                time.sleep(1)
            print("🛑 Both servers have stopped - exiting")
        except KeyboardInterrupt:
            print("🛑 Shutting down servers...")