yukee1992 commited on
Commit
1fe624c
·
verified ·
1 Parent(s): 0ccca52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1238 -155
app.py CHANGED
@@ -1,239 +1,1322 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionPipeline
4
  from PIL import Image
5
  import io
 
6
  import os
7
  from datetime import datetime
 
8
  import time
9
  import json
 
 
 
 
 
 
10
  import uuid
 
 
11
  import random
12
- import threading
13
- from typing import Optional
14
- from fastapi import FastAPI, BackgroundTasks
15
- from pydantic import BaseModel
16
- import requests
17
- from huggingface_hub import HfApi
18
 
19
  # =============================================
20
- # CONFIGURATION
21
  # =============================================
22
  HF_TOKEN = os.environ.get("HF_TOKEN")
23
  HF_USERNAME = "yukee1992"
24
  DATASET_NAME = "video-project-images"
25
  DATASET_ID = f"{HF_USERNAME}/{DATASET_NAME}"
26
 
27
- print("=" * 60)
28
- print("🚀 STARTING IMAGE GENERATOR")
29
- print("=" * 60)
30
  print(f"📦 HF Dataset: {DATASET_ID}")
31
  print(f"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing'}")
32
 
33
- # Create backup directory
34
- BACKUP_DIR = "generated_images_backup"
35
- os.makedirs(BACKUP_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Initialize FastAPI
38
- app = FastAPI(title="Image Generator API")
 
 
 
 
 
 
 
 
 
 
39
 
40
- # Global model cache
41
- model = None
 
 
 
42
  model_lock = threading.Lock()
43
 
44
- # =============================================
45
- # MODEL LOADING
46
- # =============================================
47
- def load_model():
48
- global model
49
- if model is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with model_lock:
51
- if model is None:
52
- print("🔄 Loading model...")
53
- model = StableDiffusionPipeline.from_pretrained(
54
- "runwayml/stable-diffusion-v1-5",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  torch_dtype=torch.float32,
56
- safety_checker=None
 
57
  ).to("cpu")
58
- print("✅ Model loaded!")
59
- return model
 
 
 
 
 
 
 
 
 
60
 
61
- # Preload model at startup
62
- load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # =============================================
65
- # HF DATASET FUNCTIONS
66
  # =============================================
67
- def upload_to_hf_dataset(image, project_id, scene_num):
68
- """Upload image to HF Dataset"""
 
69
  if not HF_TOKEN:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return None
71
 
72
  try:
73
- # Convert image to bytes
74
- img_bytes = io.BytesIO()
75
- image.save(img_bytes, format='PNG')
76
- img_data = img_bytes.getvalue()
77
-
78
- # Create filename
79
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
80
- filename = f"scene_{scene_num:03d}_{timestamp}.png"
81
- path_in_repo = f"data/projects/{project_id}/{filename}"
82
 
83
- # Upload
84
  api = HfApi(token=HF_TOKEN)
85
  api.upload_file(
86
- path_or_fileobj=img_data,
87
  path_in_repo=path_in_repo,
88
  repo_id=DATASET_ID,
89
  repo_type="dataset"
90
  )
91
 
92
  url = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{path_in_repo}"
 
93
  return url
 
94
  except Exception as e:
95
- print(f"❌ Upload failed: {e}")
96
  return None
97
 
98
- # =============================================
99
- # IMAGE GENERATION
100
- # =============================================
101
- def generate_image(prompt, project_id=None, scene_num=1):
102
- """Generate a single image"""
103
  try:
104
- pipe = load_model()
 
 
105
 
106
- # Generate
107
- image = pipe(
108
- prompt,
109
- num_inference_steps=25,
110
- guidance_scale=7.5,
111
- generator=torch.Generator(device="cpu").manual_seed(random.randint(1, 999999))
112
- ).images[0]
113
 
114
- # Save locally
115
- local_path = os.path.join(BACKUP_DIR, f"{uuid.uuid4()}.png")
116
- image.save(local_path)
117
 
118
- # Upload to HF Dataset if project_id provided
119
- hf_url = None
120
- if project_id:
121
- hf_url = upload_to_hf_dataset(image, project_id, scene_num)
122
 
123
- return {
124
- "image": image,
125
- "local_path": local_path,
126
- "hf_url": hf_url
127
- }
128
  except Exception as e:
129
- print(f"❌ Generation failed: {e}")
130
- raise
131
 
132
  # =============================================
133
- # API ENDPOINTS
 
134
  # =============================================
135
- class GenerateRequest(BaseModel):
136
- prompt: str
137
- project_id: Optional[str] = None
138
 
139
- @app.post("/api/generate") # Changed to /api/generate
140
- async def generate(request: GenerateRequest):
141
- """Generate a single image"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  try:
143
- result = generate_image(request.prompt, request.project_id)
144
- return {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  "status": "success",
146
- "hf_url": result["hf_url"],
147
- "local_path": result["local_path"]
 
 
 
 
 
 
 
 
148
  }
 
 
 
 
 
149
  except Exception as e:
150
- return {"status": "error", "message": str(e)}
 
 
151
 
152
- @app.get("/api/health") # Changed to /api/health
153
- async def health():
154
- """Health check"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  return {
156
  "status": "healthy",
157
- "model_loaded": model is not None,
158
- "hf_dataset": DATASET_ID if HF_TOKEN else "disabled"
 
 
 
 
159
  }
160
 
161
- @app.get("/api/test")
162
- async def api_test():
163
- """Simple test endpoint"""
164
- return {"message": "API is working!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- # =============================================
167
- # GRADIO INTERFACE
168
- # =============================================
169
- def gradio_generate(prompt):
170
- if not prompt:
171
- return None
172
- result = generate_image(prompt)
173
- return result["image"]
174
-
175
- # Create Gradio interface
176
- with gr.Blocks(title="Image Generator", theme=gr.themes.Soft()) as demo:
177
- gr.Markdown("# 🎨 Image Generator")
178
- gr.Markdown("Generate images using Stable Diffusion")
179
-
180
- with gr.Row():
181
- with gr.Column():
182
- prompt_input = gr.Textbox(
183
- label="Prompt",
184
- placeholder="Enter your prompt...",
185
- lines=3
186
- )
187
- generate_btn = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
 
 
 
 
188
 
189
- with gr.Column():
190
- output_image = gr.Image(label="Generated Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- generate_btn.click(
193
- fn=gradio_generate,
194
- inputs=[prompt_input],
195
- outputs=[output_image]
196
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
- gr.Markdown("---")
199
- gr.Markdown("### API Endpoints")
200
- gr.Markdown("""
201
- - `GET /api/health` - Health check
202
- - `POST /api/generate` - Generate image
203
- - `GET /api/test` - Test endpoint
204
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- # =============================================
207
- # ROOT ENDPOINT
208
- # =============================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  @app.get("/")
210
  async def root():
211
  return {
212
- "name": "Image Generator",
213
- "version": "1.0.0",
214
  "api_endpoints": {
215
- "health": "GET /api/health",
216
- "generate": "POST /api/generate",
217
- "test": "GET /api/test"
 
 
 
 
 
 
 
 
 
 
 
 
218
  },
219
- "ui": "Gradio interface available at /ui",
220
- "status": "running"
221
  }
222
 
223
- # =============================================
224
- # MAIN - Hugging Face Spaces Deployment
225
- # =============================================
 
 
 
 
 
 
 
 
 
 
 
226
  if __name__ == "__main__":
227
  import uvicorn
 
228
 
229
- print("\n" + "=" * 60)
230
- print("🌐 Deploying on Hugging Face Spaces")
231
- print("📚 API endpoints: /api/*")
232
- print("🎨 UI: /ui")
233
- print("=" * 60)
234
-
235
- # Mount Gradio at /ui path
236
- app = gr.mount_gradio_app(app, demo, path="/ui")
237
 
238
- # Run the combined app
239
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
4
  from PIL import Image
5
  import io
6
+ import requests
7
  import os
8
  from datetime import datetime
9
+ import re
10
  import time
11
  import json
12
+ from typing import List, Optional, Dict
13
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
14
+ from pydantic import BaseModel
15
+ import gc
16
+ import psutil
17
+ import threading
18
  import uuid
19
+ import hashlib
20
+ from enum import Enum
21
  import random
22
+ import time
23
+ from requests.adapters import HTTPAdapter
24
+ from urllib3.util.retry import Retry
25
+ from huggingface_hub import HfApi # NEW: Add this import
 
 
26
 
27
  # =============================================
28
+ # HUGGING FACE DATASET CONFIGURATION (NEW)
29
  # =============================================
30
  HF_TOKEN = os.environ.get("HF_TOKEN")
31
  HF_USERNAME = "yukee1992"
32
  DATASET_NAME = "video-project-images"
33
  DATASET_ID = f"{HF_USERNAME}/{DATASET_NAME}"
34
 
 
 
 
35
  print(f"📦 HF Dataset: {DATASET_ID}")
36
  print(f"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing'}")
37
 
38
+ # Create local directories for test images
39
+ PERSISTENT_IMAGE_DIR = "generated_test_images"
40
+ os.makedirs(PERSISTENT_IMAGE_DIR, exist_ok=True)
41
+ print(f"📁 Created local image directory: {PERSISTENT_IMAGE_DIR}")
42
+
43
+ # Initialize FastAPI app
44
+ app = FastAPI(title="Storybook Generator API")
45
+
46
+ # Add CORS middleware
47
+ from fastapi.middleware.cors import CORSMiddleware
48
+ app.add_middleware(
49
+ CORSMiddleware,
50
+ allow_origins=["*"],
51
+ allow_credentials=True,
52
+ allow_methods=["*"],
53
+ allow_headers=["*"],
54
+ )
55
+
56
+ # Job Status Enum
57
+ class JobStatus(str, Enum):
58
+ PENDING = "pending"
59
+ PROCESSING = "processing"
60
+ COMPLETED = "completed"
61
+ FAILED = "failed"
62
+
63
+ # Simple Story scene model
64
+ class StoryScene(BaseModel):
65
+ visual: str
66
+ text: str
67
+
68
+ class CharacterDescription(BaseModel):
69
+ name: str
70
+ description: str
71
+
72
+ class StorybookRequest(BaseModel):
73
+ story_title: str
74
+ scenes: List[StoryScene]
75
+ characters: List[CharacterDescription] = []
76
+ model_choice: str = "dreamshaper-8"
77
+ style: str = "childrens_book"
78
+ callback_url: Optional[str] = None
79
+ consistency_seed: Optional[int] = None
80
+ project_id: Optional[str] = None # ADDED for HF Dataset organization
81
+
82
+ class JobStatusResponse(BaseModel):
83
+ job_id: str
84
+ status: JobStatus
85
+ progress: int
86
+ message: str
87
+ result: Optional[dict] = None
88
+ created_at: float
89
+ updated_at: float
90
+
91
+ class MemoryClearanceRequest(BaseModel):
92
+ clear_models: bool = True
93
+ clear_jobs: bool = False
94
+ clear_local_images: bool = False
95
+ force_gc: bool = True
96
+
97
+ class MemoryStatusResponse(BaseModel):
98
+ memory_used_mb: float
99
+ memory_percent: float
100
+ models_loaded: int
101
+ active_jobs: int
102
+ local_images_count: int
103
+ gpu_memory_allocated_mb: Optional[float] = None
104
+ gpu_memory_cached_mb: Optional[float] = None
105
+ status: str
106
 
107
+ # HIGH-QUALITY MODEL SELECTION - ANIME FOCUSED & WORKING
108
+ MODEL_CHOICES = {
109
+ "dreamshaper-8": "lykon/dreamshaper-8",
110
+ "realistic-vision": "SG161222/Realistic_Vision_V5.1",
111
+ "counterfeit": "gsdf/Counterfeit-V2.5",
112
+ "pastel-mix": "andite/pastel-mix",
113
+ "meina-mix": "Meina/MeinaMix",
114
+ "meina-pastel": "Meina/MeinaPastel",
115
+ "abyss-orange": "warriorxza/AbyssOrangeMix",
116
+ "openjourney": "prompthero/openjourney",
117
+ "sd-1.5": "runwayml/stable-diffusion-v1-5",
118
+ }
119
 
120
+ # GLOBAL STORAGE
121
+ job_storage = {}
122
+ model_cache = {}
123
+ current_model_name = None
124
+ current_pipe = None
125
  model_lock = threading.Lock()
126
 
127
+ # MEMORY MANAGEMENT FUNCTIONS
128
+ def get_memory_usage():
129
+ """Get current memory usage statistics"""
130
+ process = psutil.Process()
131
+ memory_info = process.memory_info()
132
+ memory_used_mb = memory_info.rss / (1024 * 1024)
133
+ memory_percent = process.memory_percent()
134
+
135
+ # GPU memory if available
136
+ gpu_memory_allocated_mb = None
137
+ gpu_memory_cached_mb = None
138
+
139
+ if torch.cuda.is_available():
140
+ gpu_memory_allocated_mb = torch.cuda.memory_allocated() / (1024 * 1024)
141
+ gpu_memory_cached_mb = torch.cuda.memory_reserved() / (1024 * 1024)
142
+
143
+ return {
144
+ "memory_used_mb": round(memory_used_mb, 2),
145
+ "memory_percent": round(memory_percent, 2),
146
+ "gpu_memory_allocated_mb": round(gpu_memory_allocated_mb, 2) if gpu_memory_allocated_mb else None,
147
+ "gpu_memory_cached_mb": round(gpu_memory_cached_mb, 2) if gpu_memory_cached_mb else None,
148
+ "models_loaded": len(model_cache),
149
+ "active_jobs": len(job_storage),
150
+ "local_images_count": len(refresh_local_images())
151
+ }
152
+
153
+ def clear_memory(clear_models=True, clear_jobs=False, clear_local_images=False, force_gc=True):
154
+ """Clear memory by unloading models and cleaning up resources"""
155
+ results = []
156
+
157
+ # Clear model cache
158
+ if clear_models:
159
  with model_lock:
160
+ models_cleared = len(model_cache)
161
+ for model_name, pipe in model_cache.items():
162
+ try:
163
+ # Move to CPU first if it's on GPU
164
+ if hasattr(pipe, 'to'):
165
+ pipe.to('cpu')
166
+
167
+ # Delete the pipeline
168
+ del pipe
169
+ results.append(f"Unloaded model: {model_name}")
170
+ except Exception as e:
171
+ results.append(f"Error unloading {model_name}: {str(e)}")
172
+
173
+ model_cache.clear()
174
+ global current_pipe, current_model_name
175
+ current_pipe = None
176
+ current_model_name = None
177
+ results.append(f"Cleared {models_cleared} models from cache")
178
+
179
+ # Clear completed jobs
180
+ if clear_jobs:
181
+ jobs_to_clear = []
182
+ for job_id, job_data in job_storage.items():
183
+ if job_data["status"] in [JobStatus.COMPLETED, JobStatus.FAILED]:
184
+ jobs_to_clear.append(job_id)
185
+
186
+ for job_id in jobs_to_clear:
187
+ del job_storage[job_id]
188
+ results.append(f"Cleared job: {job_id}")
189
+
190
+ results.append(f"Cleared {len(jobs_to_clear)} completed/failed jobs")
191
+
192
+ # Clear local images
193
+ if clear_local_images:
194
+ try:
195
+ storage_info = get_local_storage_info()
196
+ deleted_count = 0
197
+ if "images" in storage_info:
198
+ for image_info in storage_info["images"]:
199
+ success, _ = delete_local_image(image_info["path"])
200
+ if success:
201
+ deleted_count += 1
202
+ results.append(f"Deleted {deleted_count} local images")
203
+ except Exception as e:
204
+ results.append(f"Error clearing local images: {str(e)}")
205
+
206
+ # Force garbage collection
207
+ if force_gc:
208
+ gc.collect()
209
+ if torch.cuda.is_available():
210
+ torch.cuda.empty_cache()
211
+ torch.cuda.synchronize()
212
+ results.append("GPU cache cleared")
213
+ results.append("Garbage collection forced")
214
+
215
+ # Get memory status after cleanup
216
+ memory_status = get_memory_usage()
217
+
218
+ return {
219
+ "status": "success",
220
+ "actions_performed": results,
221
+ "memory_after_cleanup": memory_status
222
+ }
223
+
224
+ def load_model(model_name="dreamshaper-8"):
225
+ """Thread-safe model loading with HIGH-QUALITY settings and better error handling"""
226
+ global model_cache, current_model_name, current_pipe
227
+
228
+ with model_lock:
229
+ if model_name in model_cache:
230
+ current_pipe = model_cache[model_name]
231
+ current_model_name = model_name
232
+ return current_pipe
233
+
234
+ print(f"🔄 Loading HIGH-QUALITY model: {model_name}")
235
+ try:
236
+ model_id = MODEL_CHOICES.get(model_name, "lykon/dreamshaper-8")
237
+
238
+ print(f"🔧 Attempting to load: {model_id}")
239
+
240
+ pipe = StableDiffusionPipeline.from_pretrained(
241
+ model_id,
242
+ torch_dtype=torch.float32,
243
+ safety_checker=None,
244
+ requires_safety_checker=False,
245
+ local_files_only=False, # Allow downloading if not cached
246
+ cache_dir="./model_cache" # Specific cache directory
247
+ )
248
+
249
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
250
+ pipe = pipe.to("cpu")
251
+
252
+ model_cache[model_name] = pipe
253
+ current_pipe = pipe
254
+ current_model_name = model_name
255
+
256
+ print(f"✅ HIGH-QUALITY Model loaded: {model_name}")
257
+ return pipe
258
+
259
+ except Exception as e:
260
+ print(f"❌ Model loading failed for {model_name}: {e}")
261
+ print(f"🔄 Falling back to stable-diffusion-v1-5")
262
+
263
+ # Fallback to base model
264
+ try:
265
+ pipe = StableDiffusionPipeline.from_pretrained(
266
+ "runwayml/stable-diffusion-v1-5",
267
  torch_dtype=torch.float32,
268
+ safety_checker=None,
269
+ requires_safety_checker=False
270
  ).to("cpu")
271
+
272
+ model_cache[model_name] = pipe
273
+ current_pipe = pipe
274
+ current_model_name = "sd-1.5"
275
+
276
+ print(f"✅ Fallback model loaded: stable-diffusion-v1-5")
277
+ return pipe
278
+
279
+ except Exception as fallback_error:
280
+ print(f"❌ Critical: Fallback model also failed: {fallback_error}")
281
+ raise
282
 
283
+ # Initialize default model
284
+ print("🚀 Initializing Storybook Generator API...")
285
+ load_model("dreamshaper-8")
286
+ print("✅ Model loaded and ready!")
287
+
288
+ # SIMPLE PROMPT ENGINEERING - USE PURE PROMPTS ONLY
289
+ def enhance_prompt_simple(scene_visual, style="childrens_book"):
290
+ """Simple prompt enhancement - uses only the provided visual prompt with style"""
291
+
292
+ # Style templates
293
+ style_templates = {
294
+ "childrens_book": "children's book illustration, watercolor style, soft colors, whimsical, magical, storybook art, professional illustration",
295
+ "realistic": "photorealistic, detailed, natural lighting, professional photography",
296
+ "fantasy": "fantasy art, magical, ethereal, digital painting, concept art",
297
+ "anime": "anime style, Japanese animation, vibrant colors, detailed artwork"
298
+ }
299
+
300
+ style_prompt = style_templates.get(style, style_templates["childrens_book"])
301
+
302
+ # Use only the provided visual prompt with style
303
+ enhanced_prompt = f"{style_prompt}, {scene_visual}"
304
+
305
+ # Basic negative prompt for quality
306
+ negative_prompt = (
307
+ "blurry, low quality, bad anatomy, deformed characters, "
308
+ "wrong proportions, mismatched features"
309
+ )
310
+
311
+ return enhanced_prompt, negative_prompt
312
+
313
+ def generate_image_simple(prompt, model_choice, style, scene_number, consistency_seed=None):
314
+ """Generate image using pure prompts only"""
315
+
316
+ # Enhance prompt with simple style addition
317
+ enhanced_prompt, negative_prompt = enhance_prompt_simple(prompt, style)
318
+
319
+ # Use seed if provided
320
+ if consistency_seed:
321
+ scene_seed = consistency_seed + scene_number
322
+ else:
323
+ scene_seed = random.randint(1000, 9999)
324
+
325
+ try:
326
+ pipe = load_model(model_choice)
327
+
328
+ image = pipe(
329
+ prompt=enhanced_prompt,
330
+ negative_prompt=negative_prompt,
331
+ num_inference_steps=35,
332
+ guidance_scale=7.5,
333
+ width=768,
334
+ height=1024, # Portrait for better full-body
335
+ generator=torch.Generator(device="cpu").manual_seed(scene_seed)
336
+ ).images[0]
337
+
338
+ print(f"✅ Generated image for scene {scene_number}")
339
+ print(f"🌱 Seed used: {scene_seed}")
340
+ print(f"📝 Pure prompt used: {prompt}")
341
+
342
+ return image
343
+
344
+ except Exception as e:
345
+ print(f"❌ Generation failed: {str(e)}")
346
+ raise
347
+
348
+ # LOCAL FILE MANAGEMENT FUNCTIONS
349
+ def save_image_to_local(image, prompt, style="test"):
350
+ """Save image to local persistent storage"""
351
+ try:
352
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
353
+ safe_prompt = "".join(c for c in prompt[:50] if c.isalnum() or c in (' ', '-', '_')).rstrip()
354
+ filename = f"image_{safe_prompt}_{timestamp}.png"
355
+
356
+ # Create style subfolder
357
+ style_dir = os.path.join(PERSISTENT_IMAGE_DIR, style)
358
+ os.makedirs(style_dir, exist_ok=True)
359
+ filepath = os.path.join(style_dir, filename)
360
+
361
+ # Save the image
362
+ image.save(filepath)
363
+ print(f"💾 Image saved locally: {filepath}")
364
+
365
+ return filepath, filename
366
+
367
+ except Exception as e:
368
+ print(f"❌ Failed to save locally: {e}")
369
+ return None, None
370
+
371
+ def delete_local_image(filepath):
372
+ """Delete an image from local storage"""
373
+ try:
374
+ if os.path.exists(filepath):
375
+ os.remove(filepath)
376
+ print(f"🗑️ Deleted local image: {filepath}")
377
+ return True, f"✅ Deleted: {os.path.basename(filepath)}"
378
+ else:
379
+ return False, f"❌ File not found: {filepath}"
380
+ except Exception as e:
381
+ return False, f"❌ Error deleting: {str(e)}"
382
+
383
+ def get_local_storage_info():
384
+ """Get information about local storage usage"""
385
+ try:
386
+ total_size = 0
387
+ file_count = 0
388
+ images_list = []
389
+
390
+ for root, dirs, files in os.walk(PERSISTENT_IMAGE_DIR):
391
+ for file in files:
392
+ if file.endswith(('.png', '.jpg', '.jpeg')):
393
+ filepath = os.path.join(root, file)
394
+ if os.path.exists(filepath):
395
+ file_size = os.path.getsize(filepath)
396
+ total_size += file_size
397
+ file_count += 1
398
+ images_list.append({
399
+ 'path': filepath,
400
+ 'filename': file,
401
+ 'size_kb': round(file_size / 1024, 1),
402
+ 'created': os.path.getctime(filepath)
403
+ })
404
+
405
+ return {
406
+ "total_files": file_count,
407
+ "total_size_mb": round(total_size / (1024 * 1024), 2),
408
+ "images": sorted(images_list, key=lambda x: x['created'], reverse=True)
409
+ }
410
+ except Exception as e:
411
+ return {"error": str(e)}
412
+
413
+ def refresh_local_images():
414
+ """Get list of all locally saved images"""
415
+ try:
416
+ image_files = []
417
+ for root, dirs, files in os.walk(PERSISTENT_IMAGE_DIR):
418
+ for file in files:
419
+ if file.endswith(('.png', '.jpg', '.jpeg')):
420
+ filepath = os.path.join(root, file)
421
+ if os.path.exists(filepath):
422
+ image_files.append(filepath)
423
+ return image_files
424
+ except Exception as e:
425
+ print(f"Error refreshing local images: {e}")
426
+ return []
427
 
428
  # =============================================
429
+ # NEW: HUGGING FACE DATASET FUNCTIONS
430
  # =============================================
431
+
432
+ def ensure_dataset_exists():
433
+ """Create dataset if it doesn't exist"""
434
  if not HF_TOKEN:
435
+ print("⚠️ HF_TOKEN not set, cannot create/verify dataset")
436
+ return False
437
+
438
+ try:
439
+ api = HfApi(token=HF_TOKEN)
440
+ try:
441
+ api.dataset_info(DATASET_ID)
442
+ print(f"✅ Dataset {DATASET_ID} exists")
443
+ except Exception:
444
+ print(f"📦 Creating dataset: {DATASET_ID}")
445
+ api.create_repo(
446
+ repo_id=DATASET_ID,
447
+ repo_type="dataset",
448
+ private=False,
449
+ exist_ok=True
450
+ )
451
+ print(f"✅ Created dataset: {DATASET_ID}")
452
+ return True
453
+ except Exception as e:
454
+ print(f"❌ Failed to ensure dataset: {e}")
455
+ return False
456
+
457
+ def upload_to_hf_dataset(file_content, filename, subfolder=""):
458
+ """Upload a file to Hugging Face Dataset"""
459
+ if not HF_TOKEN:
460
+ print("⚠️ HF_TOKEN not set, skipping upload")
461
  return None
462
 
463
  try:
464
+ if subfolder:
465
+ path_in_repo = f"data/{subfolder}/{filename}"
466
+ else:
467
+ path_in_repo = f"data/{filename}"
 
 
 
 
 
468
 
 
469
  api = HfApi(token=HF_TOKEN)
470
  api.upload_file(
471
+ path_or_fileobj=file_content,
472
  path_in_repo=path_in_repo,
473
  repo_id=DATASET_ID,
474
  repo_type="dataset"
475
  )
476
 
477
  url = f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{path_in_repo}"
478
+ print(f"✅ Uploaded to HF Dataset: {url}")
479
  return url
480
+
481
  except Exception as e:
482
+ print(f"❌ Failed to upload to HF Dataset: {e}")
483
  return None
484
 
485
+ def upload_image_to_hf_dataset(image, project_id, page_number, prompt, style=""):
486
+ """Upload generated image to HF Dataset"""
 
 
 
487
  try:
488
+ img_bytes = io.BytesIO()
489
+ image.save(img_bytes, format='PNG')
490
+ img_data = img_bytes.getvalue()
491
 
492
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
493
+ safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
494
+ safe_prompt = safe_prompt.replace(' ', '_')
495
+ filename = f"page_{page_number:03d}_{safe_prompt}_{timestamp}.png"
 
 
 
496
 
497
+ subfolder = f"projects/{project_id}"
498
+ url = upload_to_hf_dataset(img_data, filename, subfolder)
 
499
 
500
+ return url
 
 
 
501
 
 
 
 
 
 
502
  except Exception as e:
503
+ print(f"❌ Failed to upload image to HF Dataset: {e}")
504
+ return None
505
 
506
  # =============================================
507
+ # REMOVED: OCI BUCKET FUNCTIONS
508
+ # (save_to_oci_bucket and test_oci_connection are removed)
509
  # =============================================
 
 
 
510
 
511
+ # JOB MANAGEMENT FUNCTIONS
512
+ def create_job(story_request: StorybookRequest) -> str:
513
+ job_id = str(uuid.uuid4())
514
+
515
+ job_storage[job_id] = {
516
+ "status": JobStatus.PENDING,
517
+ "progress": 0,
518
+ "message": "Job created and queued",
519
+ "request": story_request.dict(),
520
+ "result": None,
521
+ "created_at": time.time(),
522
+ "updated_at": time.time(),
523
+ "pages": []
524
+ }
525
+
526
+ print(f"📝 Created job {job_id} for story: {story_request.story_title}")
527
+ print(f"📄 Scenes to generate: {len(story_request.scenes)}")
528
+
529
+ return job_id
530
+
531
+ def update_job_status(job_id: str, status: JobStatus, progress: int, message: str, result=None):
532
+ if job_id not in job_storage:
533
+ return False
534
+
535
+ job_storage[job_id].update({
536
+ "status": status,
537
+ "progress": progress,
538
+ "message": message,
539
+ "updated_at": time.time()
540
+ })
541
+
542
+ if result:
543
+ job_storage[job_id]["result"] = result
544
+
545
+ # Send webhook notification if callback URL exists
546
+ job_data = job_storage[job_id]
547
+ request_data = job_data["request"]
548
+
549
+ if request_data.get("callback_url"):
550
+ try:
551
+ callback_url = request_data["callback_url"]
552
+
553
+ callback_data = {
554
+ "job_id": job_id,
555
+ "status": status.value,
556
+ "progress": progress,
557
+ "message": message,
558
+ "story_title": request_data["story_title"],
559
+ "total_scenes": len(request_data["scenes"]),
560
+ "timestamp": time.time(),
561
+ "source": "huggingface-image-generator",
562
+ "estimated_time_remaining": calculate_remaining_time(job_id, progress)
563
+ }
564
+
565
+ if status == JobStatus.PROCESSING:
566
+ total_scenes = len(request_data["scenes"])
567
+ if total_scenes > 0:
568
+ current_scene = min((progress - 5) // (90 // total_scenes) + 1, total_scenes)
569
+ callback_data["current_scene"] = current_scene
570
+ callback_data["total_scenes"] = total_scenes
571
+
572
+ if current_scene <= len(request_data["scenes"]):
573
+ scene_data = request_data["scenes"][current_scene-1]
574
+ callback_data["scene_description"] = scene_data.get("visual", "")[:100] + "..."
575
+ callback_data["current_prompt"] = scene_data.get("visual", "")
576
+
577
+ if status == JobStatus.COMPLETED and result:
578
+ callback_data["result"] = {
579
+ "total_pages": result.get("total_pages", 0),
580
+ "generation_time": result.get("generation_time", 0),
581
+ "hf_dataset_url": result.get("hf_dataset_url", ""),
582
+ "pages_generated": result.get("generated_pages", 0),
583
+ "consistency_seed": result.get("consistency_seed", None),
584
+ "image_urls": result.get("image_urls", [])
585
+ }
586
+
587
+ headers = {
588
+ 'Content-Type': 'application/json',
589
+ 'User-Agent': 'Storybook-Generator/1.0'
590
+ }
591
+
592
+ print(f"📢 Sending callback to: {callback_url}")
593
+
594
+ response = requests.post(
595
+ callback_url,
596
+ json=callback_data,
597
+ headers=headers,
598
+ timeout=30
599
+ )
600
+
601
+ print(f"📢 Callback sent: Status {response.status_code}")
602
+
603
+ except Exception as e:
604
+ print(f"⚠️ Callback failed: {str(e)}")
605
+
606
+ return True
607
+
608
+ def calculate_remaining_time(job_id, progress):
609
+ """Calculate estimated time remaining"""
610
+ if progress == 0:
611
+ return "Calculating..."
612
+
613
+ job_data = job_storage.get(job_id)
614
+ if not job_data:
615
+ return "Unknown"
616
+
617
+ time_elapsed = time.time() - job_data["created_at"]
618
+ if progress > 0:
619
+ total_estimated = (time_elapsed / progress) * 100
620
+ remaining = total_estimated - time_elapsed
621
+ return f"{int(remaining // 60)}m {int(remaining % 60)}s"
622
+
623
+ return "Unknown"
624
+
625
+ # UPDATED BACKGROUND TASK - Uses HF Dataset instead of OCI
626
+ def generate_storybook_background(job_id: str):
627
+ """Background task to generate complete storybook and upload to HF Dataset"""
628
  try:
629
+ # Ensure HF Dataset exists
630
+ if HF_TOKEN:
631
+ ensure_dataset_exists()
632
+
633
+ job_data = job_storage[job_id]
634
+ story_request_data = job_data["request"]
635
+ story_request = StorybookRequest(**story_request_data)
636
+
637
+ # Use project_id from request or generate from story title
638
+ project_id = story_request.project_id or story_request.story_title.replace(' ', '_').lower()
639
+
640
+ print(f"🎬 Starting storybook generation for job {job_id}")
641
+ print(f"📖 Story: {story_request.story_title}")
642
+ print(f"📄 Scenes: {len(story_request.scenes)}")
643
+ print(f"🎨 Style: {story_request.style}")
644
+ print(f"📦 Project ID: {project_id}")
645
+
646
+ update_job_status(job_id, JobStatus.PROCESSING, 5, "Starting storybook generation with pure prompts...")
647
+
648
+ total_scenes = len(story_request.scenes)
649
+ generated_pages = []
650
+ image_urls = []
651
+ start_time = time.time()
652
+
653
+ for i, scene in enumerate(story_request.scenes):
654
+ progress = 5 + int(((i + 1) / total_scenes) * 90)
655
+
656
+ update_job_status(
657
+ job_id,
658
+ JobStatus.PROCESSING,
659
+ progress,
660
+ f"Generating page {i+1}/{total_scenes}: {scene.visual[:50]}..."
661
+ )
662
+
663
+ try:
664
+ print(f"🖼️ Generating page {i+1}")
665
+ print(f"📝 Pure prompt: {scene.visual}")
666
+
667
+ # Generate image using pure prompt only
668
+ image = generate_image_simple(
669
+ scene.visual,
670
+ story_request.model_choice,
671
+ story_request.style,
672
+ i + 1,
673
+ story_request.consistency_seed
674
+ )
675
+
676
+ # Save locally as backup
677
+ local_filepath, local_filename = save_image_to_local(image, scene.visual, story_request.style)
678
+ print(f"💾 Image saved locally as backup: {local_filename}")
679
+
680
+ # Upload to HF Dataset
681
+ hf_url = None
682
+ if HF_TOKEN:
683
+ hf_url = upload_image_to_hf_dataset(
684
+ image,
685
+ project_id,
686
+ i + 1,
687
+ scene.visual,
688
+ story_request.style
689
+ )
690
+
691
+ if hf_url:
692
+ image_urls.append(hf_url)
693
+ print(f"✅ Uploaded to HF Dataset: {hf_url}")
694
+
695
+ # Store page data
696
+ page_data = {
697
+ "page_number": i + 1,
698
+ "image_url": hf_url or f"local://{local_filepath}",
699
+ "hf_dataset_url": hf_url,
700
+ "text_content": scene.text,
701
+ "visual_description": scene.visual,
702
+ "prompt_used": scene.visual,
703
+ "local_backup_path": local_filepath
704
+ }
705
+ generated_pages.append(page_data)
706
+
707
+ print(f"✅ Page {i+1} completed")
708
+
709
+ except Exception as e:
710
+ error_msg = f"Failed to generate page {i+1}: {str(e)}"
711
+ print(f"❌ {error_msg}")
712
+ update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
713
+ return
714
+
715
+ # Complete the job
716
+ generation_time = time.time() - start_time
717
+
718
+ # Count successful HF uploads
719
+ hf_success_count = len(image_urls)
720
+ local_fallback_count = total_scenes - hf_success_count
721
+
722
+ result = {
723
+ "story_title": story_request.story_title,
724
+ "project_id": project_id,
725
+ "total_pages": total_scenes,
726
+ "generated_pages": len(generated_pages),
727
+ "generation_time": round(generation_time, 2),
728
+ "hf_dataset_url": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
729
+ "consistency_seed": story_request.consistency_seed,
730
+ "pages": generated_pages,
731
+ "image_urls": image_urls,
732
+ "upload_summary": {
733
+ "hf_successful": hf_success_count,
734
+ "local_fallback": local_fallback_count,
735
+ "total_attempted": total_scenes
736
+ }
737
+ }
738
+
739
+ status_message = f"🎉 Storybook completed! {len(generated_pages)} pages created in {generation_time:.2f}s."
740
+ if hf_success_count > 0:
741
+ status_message += f" {hf_success_count} images uploaded to HF Dataset."
742
+ if local_fallback_count > 0:
743
+ status_message += f" {local_fallback_count} pages saved locally."
744
+
745
+ update_job_status(
746
+ job_id,
747
+ JobStatus.COMPLETED,
748
+ 100,
749
+ status_message,
750
+ result
751
+ )
752
+
753
+ print(f"🎉 Storybook generation finished for job {job_id}")
754
+ print(f"📤 HF Uploads: {hf_success_count} successful, {local_fallback_count} local fallbacks")
755
+
756
+ except Exception as e:
757
+ error_msg = f"Story generation failed: {str(e)}"
758
+ print(f"❌ {error_msg}")
759
+ update_job_status(job_id, JobStatus.FAILED, 0, error_msg)
760
+
761
+ # FASTAPI ENDPOINTS (for n8n)
762
+ @app.post("/api/generate-storybook")
763
+ async def generate_storybook(request: dict, background_tasks: BackgroundTasks):
764
+ """Main endpoint for n8n integration - generates complete storybook using pure prompts"""
765
+ try:
766
+ print(f"📥 Received n8n request for story: {request.get('story_title', 'Unknown')}")
767
+
768
+ # Add consistency seed if not provided
769
+ if 'consistency_seed' not in request or not request['consistency_seed']:
770
+ request['consistency_seed'] = random.randint(1000, 9999)
771
+ print(f"🌱 Generated consistency seed: {request['consistency_seed']}")
772
+
773
+ # Generate project_id if not provided
774
+ if 'project_id' not in request:
775
+ request['project_id'] = request.get('story_title', 'unknown').replace(' ', '_').lower()
776
+
777
+ # Convert to Pydantic model
778
+ story_request = StorybookRequest(**request)
779
+
780
+ # Validate required fields
781
+ if not story_request.story_title or not story_request.scenes:
782
+ raise HTTPException(status_code=400, detail="story_title and scenes are required")
783
+
784
+ # Create job immediately
785
+ job_id = create_job(story_request)
786
+
787
+ # Start background processing
788
+ background_tasks.add_task(generate_storybook_background, job_id)
789
+
790
+ # Immediate response for n8n
791
+ response_data = {
792
  "status": "success",
793
+ "message": "Storybook generation started",
794
+ "job_id": job_id,
795
+ "story_title": story_request.story_title,
796
+ "project_id": request['project_id'],
797
+ "total_scenes": len(story_request.scenes),
798
+ "consistency_seed": story_request.consistency_seed,
799
+ "hf_dataset": f"https://huggingface.co/datasets/{DATASET_ID}" if HF_TOKEN else None,
800
+ "callback_url": story_request.callback_url,
801
+ "estimated_time_seconds": len(story_request.scenes) * 35,
802
+ "timestamp": datetime.now().isoformat()
803
  }
804
+
805
+ print(f"✅ Job {job_id} started for: {story_request.story_title}")
806
+
807
+ return response_data
808
+
809
  except Exception as e:
810
+ error_msg = f"API Error: {str(e)}"
811
+ print(f"❌ {error_msg}")
812
+ raise HTTPException(status_code=500, detail=error_msg)
813
 
814
+ @app.get("/api/job-status/{job_id}")
815
+ async def get_job_status_endpoint(job_id: str):
816
+ """Check job status"""
817
+ job_data = job_storage.get(job_id)
818
+ if not job_data:
819
+ raise HTTPException(status_code=404, detail="Job not found")
820
+
821
+ return JobStatusResponse(
822
+ job_id=job_id,
823
+ status=job_data["status"],
824
+ progress=job_data["progress"],
825
+ message=job_data["message"],
826
+ result=job_data["result"],
827
+ created_at=job_data["created_at"],
828
+ updated_at=job_data["updated_at"]
829
+ )
830
+
831
+ @app.get("/api/health")
832
+ async def api_health():
833
+ """Health check endpoint for n8n"""
834
  return {
835
  "status": "healthy",
836
+ "service": "storybook-generator",
837
+ "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
838
+ "hf_token_set": bool(HF_TOKEN),
839
+ "timestamp": datetime.now().isoformat(),
840
+ "active_jobs": len(job_storage),
841
+ "models_loaded": list(model_cache.keys())
842
  }
843
 
844
+ # NEW: Endpoint to get project images from HF Dataset
845
+ @app.get("/api/project-images/{project_id}")
846
+ async def get_project_images(project_id: str):
847
+ """Get all images for a project from HF Dataset"""
848
+ try:
849
+ if not HF_TOKEN:
850
+ return {"error": "HF_TOKEN not set"}
851
+
852
+ api = HfApi(token=HF_TOKEN)
853
+ files = api.list_repo_files(repo_id=DATASET_ID, repo_type="dataset")
854
+
855
+ project_files = [f for f in files if f.startswith(f"data/projects/{project_id}/")]
856
+
857
+ urls = [f"https://huggingface.co/datasets/{DATASET_ID}/resolve/main/{f}" for f in project_files]
858
+
859
+ return {
860
+ "project_id": project_id,
861
+ "total_images": len(urls),
862
+ "image_urls": urls
863
+ }
864
+ except Exception as e:
865
+ return {"error": str(e)}
866
 
867
+ # NEW MEMORY MANAGEMENT ENDPOINTS
868
+ @app.get("/api/memory-status")
869
+ async def get_memory_status():
870
+ """Get current memory usage and system status"""
871
+ memory_info = get_memory_usage()
872
+ return MemoryStatusResponse(
873
+ memory_used_mb=memory_info["memory_used_mb"],
874
+ memory_percent=memory_info["memory_percent"],
875
+ models_loaded=memory_info["models_loaded"],
876
+ active_jobs=memory_info["active_jobs"],
877
+ local_images_count=memory_info["local_images_count"],
878
+ gpu_memory_allocated_mb=memory_info["gpu_memory_allocated_mb"],
879
+ gpu_memory_cached_mb=memory_info["gpu_memory_cached_mb"],
880
+ status="healthy"
881
+ )
882
+
883
+ @app.post("/api/clear-memory")
884
+ async def clear_memory_endpoint(request: MemoryClearanceRequest):
885
+ """Clear memory by unloading models and cleaning up resources"""
886
+ try:
887
+ result = clear_memory(
888
+ clear_models=request.clear_models,
889
+ clear_jobs=request.clear_jobs,
890
+ clear_local_images=request.clear_local_images,
891
+ force_gc=request.force_gc
892
+ )
893
+
894
+ return {
895
+ "status": "success",
896
+ "message": "Memory clearance completed",
897
+ "details": result
898
+ }
899
 
900
+ except Exception as e:
901
+ raise HTTPException(status_code=500, detail=f"Memory clearance failed: {str(e)}")
902
+
903
+ @app.post("/api/auto-cleanup")
904
+ async def auto_cleanup():
905
+ """Automatic cleanup - clears completed jobs and forces GC"""
906
+ try:
907
+ result = clear_memory(
908
+ clear_models=False, # Don't clear models by default
909
+ clear_jobs=True, # Clear completed jobs
910
+ clear_local_images=False, # Don't clear images by default
911
+ force_gc=True # Force garbage collection
912
+ )
913
+
914
+ return {
915
+ "status": "success",
916
+ "message": "Automatic cleanup completed",
917
+ "details": result
918
+ }
919
+
920
+ except Exception as e:
921
+ raise HTTPException(status_code=500, detail=f"Auto cleanup failed: {str(e)}")
922
+
923
+ @app.get("/api/local-images")
924
+ async def get_local_images():
925
+ """API endpoint to get locally saved test images"""
926
+ storage_info = get_local_storage_info()
927
+ return storage_info
928
+
929
+ @app.delete("/api/local-images/{filename:path}")
930
+ async def delete_local_image_api(filename: str):
931
+ """API endpoint to delete a local image"""
932
+ try:
933
+ filepath = os.path.join(PERSISTENT_IMAGE_DIR, filename)
934
+ success, message = delete_local_image(filepath)
935
+ return {"status": "success" if success else "error", "message": message}
936
+ except Exception as e:
937
+ return {"status": "error", "message": str(e)}
938
+
939
+ # SIMPLE GRADIO INTERFACE
940
+ def create_gradio_interface():
941
+ """Create simple Gradio interface for testing"""
942
 
943
+ def generate_test_image_simple(prompt, model_choice, style_choice):
944
+ """Generate a single image using pure prompt only"""
945
+ try:
946
+ if not prompt.strip():
947
+ return None, "❌ Please enter a prompt", None
948
+
949
+ print(f"🎨 Generating test image with pure prompt: {prompt}")
950
+
951
+ # Generate the image using pure prompt
952
+ image = generate_image_simple(
953
+ prompt,
954
+ model_choice,
955
+ style_choice,
956
+ 1
957
+ )
958
+
959
+ # Save to local storage
960
+ filepath, filename = save_image_to_local(image, prompt, style_choice)
961
+
962
+ status_msg = f"""✅ Success! Generated: {prompt}
963
+
964
+ 📁 **Local file:** {filename if filename else 'Not saved'}"""
965
+
966
+ return image, status_msg, filepath
967
+
968
+ except Exception as e:
969
+ error_msg = f"❌ Generation failed: {str(e)}"
970
+ print(error_msg)
971
+ return None, error_msg, None
972
 
973
+ with gr.Blocks(title="Simple Image Generator", theme="soft") as demo:
974
+ gr.Markdown("# 🎨 Simple Image Generator")
975
+ gr.Markdown("Generate images using **pure prompts only** - no automatic enhancements")
976
+
977
+ # Storage info display
978
+ storage_info = gr.Textbox(
979
+ label="📊 Local Storage Information",
980
+ interactive=False,
981
+ lines=2
982
+ )
983
+
984
+ # Memory status display
985
+ memory_status = gr.Textbox(
986
+ label="🧠 Memory Status",
987
+ interactive=False,
988
+ lines=3
989
+ )
990
+
991
+ # HF Dataset status
992
+ hf_status = gr.Textbox(
993
+ label="📤 Hugging Face Dataset",
994
+ value=f"✅ Connected to {DATASET_ID}" if HF_TOKEN else "❌ HF_TOKEN not set - local only",
995
+ interactive=False,
996
+ lines=2
997
+ )
998
+
999
+ def update_storage_info():
1000
+ info = get_local_storage_info()
1001
+ if "error" not in info:
1002
+ return f"📁 Local Storage: {info['total_files']} images, {info['total_size_mb']} MB used"
1003
+ return "📁 Local Storage: Unable to calculate"
1004
+
1005
+ def update_memory_status():
1006
+ memory_info = get_memory_usage()
1007
+ status_text = f"🧠 Memory Usage: {memory_info['memory_used_mb']} MB ({memory_info['memory_percent']}%)\n"
1008
+ status_text += f"📦 Models Loaded: {memory_info['models_loaded']}\n"
1009
+ status_text += f"⚡ Active Jobs: {memory_info['active_jobs']}"
1010
+
1011
+ if memory_info['gpu_memory_allocated_mb']:
1012
+ status_text += f"\n🎮 GPU Memory: {memory_info['gpu_memory_allocated_mb']} MB allocated"
1013
+
1014
+ return status_text
1015
+
1016
+ with gr.Row():
1017
+ with gr.Column(scale=1):
1018
+ gr.Markdown("### 🎯 Quality Settings")
1019
+
1020
+ model_dropdown = gr.Dropdown(
1021
+ label="AI Model",
1022
+ choices=list(MODEL_CHOICES.keys()),
1023
+ value="dreamshaper-8"
1024
+ )
1025
+
1026
+ style_dropdown = gr.Dropdown(
1027
+ label="Art Style",
1028
+ choices=["childrens_book", "realistic", "fantasy", "anime"],
1029
+ value="anime"
1030
+ )
1031
+
1032
+ prompt_input = gr.Textbox(
1033
+ label="Pure Prompt",
1034
+ placeholder="Enter your exact prompt...",
1035
+ lines=3
1036
+ )
1037
+
1038
+ generate_btn = gr.Button("✨ Generate Image", variant="primary")
1039
+
1040
+ # Current image management
1041
+ current_file_path = gr.State()
1042
+ delete_btn = gr.Button("🗑️ Delete This Image", variant="stop")
1043
+ delete_status = gr.Textbox(label="Delete Status", interactive=False, lines=2)
1044
+
1045
+ # Memory management section
1046
+ gr.Markdown("### 🧠 Memory Management")
1047
+ with gr.Row():
1048
+ auto_cleanup_btn = gr.Button("🔄 Auto Cleanup", size="sm")
1049
+ clear_models_btn = gr.Button("🗑️ Clear Models", variant="stop", size="sm")
1050
+
1051
+ memory_clear_status = gr.Textbox(label="Memory Clear Status", interactive=False, lines=2)
1052
+
1053
+ gr.Markdown("### 📚 API Usage for n8n")
1054
+ gr.Markdown(f"""
1055
+ **Generate Storybook:**
1056
+ - Endpoint: `POST /api/generate-storybook`
1057
+ - Body: `{{"story_title": "...", "scenes": [...]}}`
1058
+
1059
+ **Check Status:**
1060
+ - `GET /api/job-status/{{job_id}}`
1061
+
1062
+ **HF Dataset:**
1063
+ - `{DATASET_ID if HF_TOKEN else "Set HF_TOKEN to enable"}`
1064
+ """)
1065
+
1066
+ with gr.Column(scale=2):
1067
+ image_output = gr.Image(label="Generated Image", height=500, show_download_button=True)
1068
+ status_output = gr.Textbox(label="Status", interactive=False, lines=4)
1069
+
1070
+ # Local file management section
1071
+ with gr.Accordion("📁 Manage Local Test Images", open=True):
1072
+ gr.Markdown("### Locally Saved Images")
1073
+
1074
+ with gr.Row():
1075
+ refresh_btn = gr.Button("🔄 Refresh List")
1076
+ clear_all_btn = gr.Button("🗑️ Clear All Images", variant="stop")
1077
+
1078
+ file_gallery = gr.Gallery(
1079
+ label="Local Images",
1080
+ show_label=True,
1081
+ elem_id="gallery",
1082
+ columns=4,
1083
+ height="auto"
1084
+ )
1085
+
1086
+ clear_status = gr.Textbox(label="Clear Status", interactive=False)
1087
+
1088
+ def delete_current_image(filepath):
1089
+ """Delete the currently displayed image"""
1090
+ if not filepath:
1091
+ return "❌ No image to delete", None, None, refresh_local_images()
1092
+
1093
+ success, message = delete_local_image(filepath)
1094
+ updated_files = refresh_local_images()
1095
+
1096
+ if success:
1097
+ status_msg = f"✅ {message}"
1098
+ return status_msg, None, "Image deleted successfully!", updated_files
1099
+ else:
1100
+ return f"❌ {message}", None, "Delete failed", updated_files
1101
 
1102
+ def clear_all_images():
1103
+ """Delete all local images"""
1104
+ try:
1105
+ storage_info = get_local_storage_info()
1106
+ deleted_count = 0
1107
+
1108
+ if "images" in storage_info:
1109
+ for image_info in storage_info["images"]:
1110
+ success, _ = delete_local_image(image_info["path"])
1111
+ if success:
1112
+ deleted_count += 1
1113
+
1114
+ updated_files = refresh_local_images()
1115
+ return f"✅ Deleted {deleted_count} images", updated_files
1116
+ except Exception as e:
1117
+ return f"❌ Error: {str(e)}", refresh_local_images()
1118
+
1119
+ def perform_auto_cleanup():
1120
+ """Perform automatic cleanup"""
1121
+ try:
1122
+ result = clear_memory(
1123
+ clear_models=False,
1124
+ clear_jobs=True,
1125
+ clear_local_images=False,
1126
+ force_gc=True
1127
+ )
1128
+ return f"✅ Auto cleanup completed: {len(result['actions_performed'])} actions"
1129
+ except Exception as e:
1130
+ return f"❌ Auto cleanup failed: {str(e)}"
1131
+
1132
+ def clear_models():
1133
+ """Clear all loaded models"""
1134
+ try:
1135
+ result = clear_memory(
1136
+ clear_models=True,
1137
+ clear_jobs=False,
1138
+ clear_local_images=False,
1139
+ force_gc=True
1140
+ )
1141
+ return f"✅ Models cleared: {len(result['actions_performed'])} actions"
1142
+ except Exception as e:
1143
+ return f"❌ Model clearance failed: {str(e)}"
1144
+
1145
+ # Connect buttons to functions
1146
+ generate_btn.click(
1147
+ fn=generate_test_image_simple,
1148
+ inputs=[prompt_input, model_dropdown, style_dropdown],
1149
+ outputs=[image_output, status_output, current_file_path]
1150
+ ).then(
1151
+ fn=refresh_local_images,
1152
+ outputs=file_gallery
1153
+ ).then(
1154
+ fn=update_storage_info,
1155
+ outputs=storage_info
1156
+ ).then(
1157
+ fn=update_memory_status,
1158
+ outputs=memory_status
1159
+ )
1160
+
1161
+ delete_btn.click(
1162
+ fn=delete_current_image,
1163
+ inputs=current_file_path,
1164
+ outputs=[delete_status, image_output, status_output, file_gallery]
1165
+ ).then(
1166
+ fn=update_storage_info,
1167
+ outputs=storage_info
1168
+ ).then(
1169
+ fn=update_memory_status,
1170
+ outputs=memory_status
1171
+ )
1172
+
1173
+ refresh_btn.click(
1174
+ fn=refresh_local_images,
1175
+ outputs=file_gallery
1176
+ ).then(
1177
+ fn=update_storage_info,
1178
+ outputs=storage_info
1179
+ ).then(
1180
+ fn=update_memory_status,
1181
+ outputs=memory_status
1182
+ )
1183
+
1184
+ clear_all_btn.click(
1185
+ fn=clear_all_images,
1186
+ outputs=[clear_status, file_gallery]
1187
+ ).then(
1188
+ fn=update_storage_info,
1189
+ outputs=storage_info
1190
+ ).then(
1191
+ fn=update_memory_status,
1192
+ outputs=memory_status
1193
+ )
1194
+
1195
+ # Memory management buttons
1196
+ auto_cleanup_btn.click(
1197
+ fn=perform_auto_cleanup,
1198
+ outputs=memory_clear_status
1199
+ ).then(
1200
+ fn=update_memory_status,
1201
+ outputs=memory_status
1202
+ )
1203
+
1204
+ clear_models_btn.click(
1205
+ fn=clear_models,
1206
+ outputs=memory_clear_status
1207
+ ).then(
1208
+ fn=update_memory_status,
1209
+ outputs=memory_status
1210
+ )
1211
+
1212
+ # Initialize on load
1213
+ demo.load(fn=refresh_local_images, outputs=file_gallery)
1214
+ demo.load(fn=update_storage_info, outputs=storage_info)
1215
+ demo.load(fn=update_memory_status, outputs=memory_status)
1216
+
1217
+ return demo
1218
+
1219
+ # Create simple Gradio app
1220
+ demo = create_gradio_interface()
1221
+
1222
+ # Simple root endpoint
1223
  @app.get("/")
1224
  async def root():
1225
  return {
1226
+ "message": "Storybook Generator API with HF Dataset is running!",
 
1227
  "api_endpoints": {
1228
+ "health_check": "GET /api/health",
1229
+ "generate_storybook": "POST /api/generate-storybook",
1230
+ "check_job_status": "GET /api/job-status/{job_id}",
1231
+ "project_images": "GET /api/project-images/{project_id}",
1232
+ "local_images": "GET /api/local-images",
1233
+ "memory_status": "GET /api/memory-status",
1234
+ "clear_memory": "POST /api/clear-memory",
1235
+ "auto_cleanup": "POST /api/auto-cleanup"
1236
+ },
1237
+ "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
1238
+ "features": {
1239
+ "pure_prompts": "✅ Enabled - No automatic enhancements",
1240
+ "n8n_integration": "✅ Enabled",
1241
+ "memory_management": "✅ Enabled",
1242
+ "hf_dataset": "✅ Enabled" if HF_TOKEN else "❌ Disabled"
1243
  },
1244
+ "web_interface": "GET /ui"
 
1245
  }
1246
 
1247
+ # Add a simple test endpoint
1248
+ @app.get("/api/test")
1249
+ async def test_endpoint():
1250
+ return {
1251
+ "status": "success",
1252
+ "message": "API is working correctly",
1253
+ "hf_dataset": DATASET_ID if HF_TOKEN else "Disabled",
1254
+ "timestamp": datetime.now().isoformat()
1255
+ }
1256
+
1257
+ # For Hugging Face Spaces deployment
1258
+ def get_app():
1259
+ return app
1260
+
1261
  if __name__ == "__main__":
1262
  import uvicorn
1263
+ import os
1264
 
1265
+ # Check if we're running on Hugging Face Spaces
1266
+ HF_SPACE = os.environ.get('SPACE_ID') is not None
 
 
 
 
 
 
1267
 
1268
+ if HF_SPACE:
1269
+ print("🚀 Running on Hugging Face Spaces - Integrated Mode")
1270
+ print(f"📦 HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
1271
+ print("📚 API endpoints available at: /api/*")
1272
+ print("🎨 Web interface available at: /ui")
1273
+ print("📝 PURE PROMPTS enabled - no automatic enhancements")
1274
+ print("🧠 MEMORY MANAGEMENT enabled - automatic cleanup available")
1275
+
1276
+ # Mount Gradio without reassigning app
1277
+ gr.mount_gradio_app(app, demo, path="/ui")
1278
+
1279
+ # Run the combined app
1280
+ uvicorn.run(
1281
+ app,
1282
+ host="0.0.0.0",
1283
+ port=7860,
1284
+ log_level="info"
1285
+ )
1286
+ else:
1287
+ # Local development - run separate servers
1288
+ print("🚀 Running locally - Separate API and UI servers")
1289
+ print(f"📦 HF Dataset: {DATASET_ID if HF_TOKEN else 'Disabled'}")
1290
+ print("📚 API endpoints: http://localhost:8000/api/*")
1291
+ print("🎨 Web interface: http://localhost:7860/ui")
1292
+ print("📝 PURE PROMPTS enabled - no automatic enhancements")
1293
+ print("🧠 MEMORY MANAGEMENT enabled - automatic cleanup available")
1294
+
1295
+ def run_fastapi():
1296
+ """Run FastAPI on port 8000 for API calls"""
1297
+ uvicorn.run(
1298
+ app,
1299
+ host="0.0.0.0",
1300
+ port=8000,
1301
+ log_level="info",
1302
+ access_log=False
1303
+ )
1304
+
1305
+ def run_gradio():
1306
+ """Run Gradio on port 7860 for web interface"""
1307
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
1308
+
1309
+ # Run both servers in separate threads
1310
+ import threading
1311
+ fastapi_thread = threading.Thread(target=run_fastapi, daemon=True)
1312
+ gradio_thread = threading.Thread(target=run_gradio, daemon=True)
1313
+
1314
+ fastapi_thread.start()
1315
+ gradio_thread.start()
1316
+
1317
+ try:
1318
+ # Keep main thread alive
1319
+ while True:
1320
+ time.sleep(1)
1321
+ except KeyboardInterrupt:
1322
+ print("🛑 Shutting down servers...")