yukee1992 commited on
Commit
adfd61d
Β·
verified Β·
1 Parent(s): a2206ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -244
app.py CHANGED
@@ -1,290 +1,270 @@
1
- # Import necessary libraries
2
  import torch
3
- from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
4
- from fastapi import FastAPI, HTTPException
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from pydantic import BaseModel
7
- import io
8
- import base64
9
  from PIL import Image
10
- import time
11
- from datetime import datetime
12
  import os
13
- from fastapi import Request
14
- from fastapi.responses import HTMLResponse
15
-
16
-
17
- # Google Drive imports
18
- from google.oauth2 import service_account
19
- from googleapiclient.discovery import build
20
- from googleapiclient.http import MediaIoBaseUpload
21
- import json
22
-
23
- # Initialize FastAPI
24
- app = FastAPI(title="Children's Book Illustrator API")
25
-
26
- # Add CORS middleware to allow requests from n8n
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"],
30
- allow_methods=["*"],
31
- allow_headers=["*"],
32
- )
33
-
34
- # Force CPU usage
35
- device = "cpu"
36
- print(f"Using device: {device}")
37
-
38
- # Load model
39
- model_id = "stabilityai/stable-diffusion-2-1"
40
- print("Loading pipeline... This may take a few minutes.")
41
 
 
42
  try:
43
- pipe = StableDiffusionPipeline.from_pretrained(
44
- model_id,
45
- torch_dtype=torch.float32,
46
- use_safetensors=True,
47
- safety_checker=None,
48
- requires_safety_checker=False
49
- )
50
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
51
- pipe = pipe.to(device)
52
- print("Model loaded successfully on CPU!")
53
- except Exception as e:
54
- print(f"Error loading model: {e}")
55
- # Fallback
56
- model_id = "dreamlike-art/dreamlike-diffusion-1.0"
57
- pipe = StableDiffusionPipeline.from_pretrained(
58
- model_id,
59
- torch_dtype=torch.float32,
60
- use_safetensors=True,
61
- safety_checker=None,
62
- requires_safety_checker=False
63
- )
64
- pipe = pipe.to(device)
65
- print(f"Fell back to {model_id}")
66
 
67
- # Google Drive Setup
68
- def setup_google_drive():
69
- """Initialize Google Drive service"""
 
70
  try:
71
- # Get service account credentials from environment variable
72
- credentials_json = os.getenv('GOOGLE_SERVICE_ACCOUNT_JSON')
73
- if not credentials_json:
74
- print("Google Drive: No service account credentials found")
75
- return None
76
-
77
- # Parse the JSON credentials
78
- service_account_info = json.loads(credentials_json)
79
- credentials = service_account.Credentials.from_service_account_info(
80
- service_account_info,
81
- scopes=['https://www.googleapis.com/auth/drive.file']
82
- )
83
-
84
- # Build the Drive service
85
- drive_service = build('drive', 'v3', credentials=credentials)
86
- print("Google Drive service initialized successfully")
87
- return drive_service
88
-
89
  except Exception as e:
90
- print(f"Google Drive setup failed: {e}")
91
- return None
92
 
93
- # Google Drive Setup with DEBUGGING
94
- def setup_google_drive():
95
- """Initialize Google Drive service with detailed debugging"""
 
 
96
  try:
97
- print("Setting up Google Drive...")
98
-
99
- # Get service account credentials from environment variable
100
- credentials_json = os.getenv('GOOGLE_SERVICE_ACCOUNT_JSON')
101
- if not credentials_json:
102
- print("❌ ERROR: GOOGLE_SERVICE_ACCOUNT_JSON environment variable not found")
103
- return None
104
-
105
- print("βœ… Found Google service account JSON")
106
 
107
- # Get Shared Drive ID
108
- SHARED_DRIVE_ID = os.getenv('SHARED_DRIVE_ID')
109
- if not SHARED_DRIVE_ID:
110
- print("❌ ERROR: SHARED_DRIVE_ID environment variable not set")
111
- return None
112
-
113
- print(f"βœ… Shared Drive ID: {SHARED_DRIVE_ID}")
114
 
115
- # Parse the JSON credentials
116
- try:
117
- service_account_info = json.loads(credentials_json)
118
- client_email = service_account_info.get('client_email', 'Unknown')
119
- print(f"βœ… Service account email: {client_email}")
120
- except json.JSONDecodeError as e:
121
- print(f"❌ ERROR: Invalid JSON in service account credentials: {e}")
122
- return None
123
 
124
- credentials = service_account.Credentials.from_service_account_info(
125
- service_account_info,
126
- scopes=['https://www.googleapis.com/auth/drive.file']
127
- )
128
 
129
- # Build the Drive service
130
- drive_service = build('drive', 'v3', credentials=credentials)
131
- print("βœ… Google Drive service initialized successfully")
132
- return drive_service
133
 
 
 
 
 
 
134
  except Exception as e:
135
- print(f"❌ Google Drive setup failed: {str(e)}")
136
- import traceback
137
- traceback.print_exc()
138
- return None
139
-
140
- # Initialize Google Drive service
141
- drive_service = setup_google_drive()
142
- SHARED_DRIVE_ID = os.getenv('SHARED_DRIVE_ID')
143
 
144
- def save_to_google_drive(image, prompt):
145
- """Save image to Google Drive Shared Drive with detailed debugging"""
146
- if not drive_service:
147
- print("❌ Google Drive service not available, skipping save")
148
- return None
149
-
150
- if not SHARED_DRIVE_ID:
151
- print("❌ Shared Drive ID not configured, skipping save")
152
- return None
153
-
154
  try:
155
- print(f"πŸ”„ Attempting to save image to Shared Drive: {SHARED_DRIVE_ID}")
156
-
157
- # Create a filename with timestamp
158
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
159
- safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
160
- filename = f"storybook_{timestamp}_{safe_prompt}.png"
161
-
162
- print(f"πŸ“ Filename: {filename}")
163
-
164
  # Convert image to bytes
165
  img_bytes = io.BytesIO()
166
  image.save(img_bytes, format='PNG')
167
  img_bytes.seek(0)
168
 
169
- # Create file metadata
170
- file_metadata = {
171
- 'name': filename,
172
- 'mimeType': 'image/png',
173
- 'parents': [SHARED_DRIVE_ID] # Save to Shared Drive
174
- }
 
 
175
 
176
- print(f"πŸ“‹ File metadata: {file_metadata}")
 
 
 
 
 
 
 
177
 
178
- # Upload to Google Drive with supportsAllDrives=True
179
- media = MediaIoBaseUpload(img_bytes, mimetype='image/png', resumable=True)
 
 
180
 
181
- print("⬆️ Starting upload to Google Drive...")
182
- file = drive_service.files().create(
183
- body=file_metadata,
184
- media_body=media,
185
- supportsAllDrives=True, # CRITICAL FOR SHARED DRIVES
186
- fields='id, webViewLink'
187
- ).execute()
188
 
189
- drive_link = file.get('webViewLink')
190
- print(f"βœ… Image saved to Google Drive: {drive_link}")
191
- return drive_link
192
 
 
 
 
 
 
 
 
 
 
193
  except Exception as e:
194
- print(f"❌ Failed to save to Google Drive: {str(e)}")
195
- import traceback
196
- traceback.print_exc()
197
- return None
198
 
199
- # Add this simple OAuth callback handler
200
- @app.get("/oauth2callback", response_class=HTMLResponse)
201
- async def oauth2_callback(request: Request):
202
- """
203
- Simple endpoint to handle OAuth2 redirect and display the authorization code
204
- """
205
- code = request.query_params.get("code")
206
- if code:
207
- # Display the code in a simple HTML page
208
- html_content = f"""
209
- <html>
210
- <head><title>Authentication Successful</title></head>
211
- <body>
212
- <h2>βœ… Authentication Successful!</h2>
213
- <p>Your authorization code has been received.</p>
214
- <p>Please copy this code and paste it back into Termux:</p>
215
- <div style="background: #f0f0f0; padding: 10px; border-radius: 5px; word-break: break-all;">
216
- <strong>{code}</strong>
217
- </div>
218
- <p><br>Then press Enter to complete the process.</p>
219
- </body>
220
- </html>
221
- """
222
- return HTMLResponse(content=html_content)
223
  else:
224
- return HTMLResponse(content="<h2>❌ No authorization code received</h2>")
225
-
226
- # Request model
227
- class GenerateRequest(BaseModel):
228
- prompt: str
229
- width: int = 512
230
- height: int = 512
231
- steps: int = 25
232
- save_to_drive: bool = True # New option to control saving
233
-
234
- # Health check endpoint
235
- @app.get("/")
236
- async def health_check():
237
- drive_status = "connected" if drive_service else "disconnected"
238
- return {"status": "healthy", "model": model_id, "google_drive": drive_status}
239
 
240
- # Main API endpoint
241
- @app.post("/generate")
242
- async def generate_image(request: GenerateRequest):
243
  try:
244
- # Enhanced prompt
245
- enhanced_prompt = f"masterpiece, best quality, 4K, ultra detailed, photorealistic, sharp focus, studio lighting, professional photography, {request.prompt}"
246
- negative_prompt = "blurry, low quality, low resolution, watermark, signature, text, ugly, deformed"
247
 
248
- print(f"Generating image for prompt: {enhanced_prompt}")
249
 
250
  # Generate image
251
  image = pipe(
252
- prompt=enhanced_prompt,
253
- negative_prompt=negative_prompt,
254
- width=request.width,
255
- height=request.height,
256
- guidance_scale=9.0,
257
- num_inference_steps=request.steps,
258
- generator=torch.Generator(device=device)
259
  ).images[0]
260
 
261
- if image.mode != 'RGB':
262
- image = image.convert('RGB')
263
 
264
- print("Image generated successfully!")
 
 
 
265
 
266
- # Save to Google Drive if enabled
267
- drive_link = None
268
- if request.save_to_drive and drive_service:
269
- drive_link = save_to_google_drive(image, request.prompt)
270
 
271
- # Convert to base64 for API response
272
- buffered = io.BytesIO()
273
- image.save(buffered, format="PNG")
274
- img_base64 = base64.b64encode(buffered.getvalue()).decode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
- return {
277
- "status": "success",
278
- "image": f"data:image/png;base64,{img_base64}",
279
- "prompt": request.prompt,
280
- "google_drive_link": drive_link,
281
- "saved_to_drive": drive_link is not None
282
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
- except Exception as e:
285
- raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- # Run the app
288
  if __name__ == "__main__":
289
- import uvicorn
290
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
1
+ import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionPipeline
 
 
 
 
 
4
  from PIL import Image
5
+ import io
6
+ import requests
7
  import os
8
+ from datetime import datetime
9
+ import re
10
+ import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Try to import your existing OCI connector for direct access
13
  try:
14
+ # This will work if we're in the same app context
15
+ from app import oci_connector
16
+ DIRECT_OCI_ACCESS = True
17
+ print("βœ… Direct OCI access available - using existing OCI connector")
18
+ except ImportError:
19
+ DIRECT_OCI_ACCESS = False
20
+ print("⚠️ Direct OCI access not available - using API endpoint")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Initialize the Stable Diffusion model
23
+ def load_model():
24
+ """Load and return the Stable Diffusion model"""
25
+ print("πŸ”„ Loading Stable Diffusion model...")
26
  try:
27
+ pipe = StableDiffusionPipeline.from_pretrained(
28
+ "runwayml/stable-diffusion-v1-5",
29
+ torch_dtype=torch.float32,
30
+ safety_checker=None, # Disable for better performance
31
+ requires_safety_checker=False
32
+ ).to("cpu")
33
+ print("βœ… Model loaded successfully!")
34
+ return pipe
 
 
 
 
 
 
 
 
 
 
35
  except Exception as e:
36
+ print(f"❌ Model loading failed: {e}")
37
+ raise e
38
 
39
+ # Load the model once at startup
40
+ pipe = load_model()
41
+
42
+ def save_to_oci_direct(image, prompt):
43
+ """Save image using direct OCI connector access"""
44
  try:
45
+ # Create temporary file
46
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
47
+ image.save(tmp, format='PNG')
48
+ temp_path = tmp.name
 
 
 
 
 
49
 
50
+ # Create organized filename
51
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
52
+ safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
53
+ filename = f"story_{timestamp}_{safe_prompt}.png"
 
 
 
54
 
55
+ # Use project-based directory structure
56
+ object_name = f"storybook-generator/childrens-books/{filename}"
 
 
 
 
 
 
57
 
58
+ # Upload using existing OCI connector
59
+ success, message = oci_connector.upload_file(temp_path, object_name, None)
 
 
60
 
61
+ # Clean up temporary file
62
+ os.unlink(temp_path)
 
 
63
 
64
+ if success:
65
+ return f"βœ… {message}"
66
+ else:
67
+ return f"❌ {message}"
68
+
69
  except Exception as e:
70
+ return f"❌ Direct upload failed: {str(e)}"
 
 
 
 
 
 
 
71
 
72
+ def save_to_oci_via_api(image, prompt):
73
+ """Save image using the OCI API endpoint"""
 
 
 
 
 
 
 
 
74
  try:
 
 
 
 
 
 
 
 
 
75
  # Convert image to bytes
76
  img_bytes = io.BytesIO()
77
  image.save(img_bytes, format='PNG')
78
  img_bytes.seek(0)
79
 
80
+ # Create filename
81
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
82
+ safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
83
+ filename = f"story_{timestamp}_{safe_prompt}.png"
84
+
85
+ # Your OCI API endpoint URL
86
+ # Use relative URL since we're in the same space
87
+ api_url = "/api/upload"
88
 
89
+ # For Hugging Face deployment, we need to handle different URL formats
90
+ try:
91
+ # Try to get the space URL from environment
92
+ space_name = os.environ.get('SPACE_NAME', 'yukee1992-oci-video-storage')
93
+ api_url = f"https://{space_name}.hf.space/api/upload"
94
+ except:
95
+ # Fallback to relative URL
96
+ pass
97
 
98
+ # Prepare form data for API request
99
+ files = {
100
+ 'file': (filename, img_bytes.getvalue(), 'image/png')
101
+ }
102
 
103
+ data = {
104
+ 'project_id': 'storybook-generator',
105
+ 'subfolder': 'childrens-books'
106
+ }
 
 
 
107
 
108
+ # Make the API request
109
+ response = requests.post(api_url, files=files, data=data)
 
110
 
111
+ if response.status_code == 200:
112
+ result = response.json()
113
+ if result['status'] == 'success':
114
+ return f"βœ… {result['message']}"
115
+ else:
116
+ return f"❌ API Error: {result.get('message', 'Unknown error')}"
117
+ else:
118
+ return f"❌ HTTP Error: {response.status_code} - {response.text}"
119
+
120
  except Exception as e:
121
+ return f"❌ API upload failed: {str(e)}"
 
 
 
122
 
123
+ def save_to_oci(image, prompt):
124
+ """Main function to save image to OCI using best available method"""
125
+ if DIRECT_OCI_ACCESS:
126
+ return save_to_oci_direct(image, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  else:
128
+ return save_to_oci_via_api(image, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ def generate_storybook_image(prompt):
131
+ """Generate an image from text prompt and save to OCI"""
 
132
  try:
133
+ # Enhance the prompt for better children's book style
134
+ enhanced_prompt = f"children's book illustration, colorful, whimsical, cute, {prompt}"
 
135
 
136
+ print(f"🎨 Generating image for prompt: {enhanced_prompt}")
137
 
138
  # Generate image
139
  image = pipe(
140
+ enhanced_prompt,
141
+ num_inference_steps=20, # Faster generation
142
+ guidance_scale=7.5
 
 
 
 
143
  ).images[0]
144
 
145
+ print("βœ… Image generated successfully!")
 
146
 
147
+ # Save to OCI
148
+ print("πŸ’Ύ Saving image to OCI storage...")
149
+ save_status = save_to_oci(image, prompt)
150
+ print(save_status)
151
 
152
+ return image, save_status
 
 
 
153
 
154
+ except Exception as e:
155
+ error_msg = f"❌ Generation failed: {str(e)}"
156
+ print(error_msg)
157
+ return None, error_msg
158
+
159
+ def batch_generate_storybook(scenes):
160
+ """Generate multiple images for a storybook"""
161
+ if not scenes:
162
+ return [], "❌ Please provide at least one scene"
163
+
164
+ results = []
165
+ status_messages = []
166
+
167
+ for i, scene in enumerate(scenes):
168
+ if not scene.strip():
169
+ continue
170
+
171
+ print(f"πŸ“– Generating scene {i+1}/{len(scenes)}: {scene}")
172
+ image, status = generate_storybook_image(scene)
173
 
174
+ if image:
175
+ results.append((f"Scene {i+1}: {scene}", image))
176
+ status_messages.append(f"Scene {i+1}: {status}")
177
+
178
+ return results, "\n".join(status_messages)
179
+
180
+ # Create the Gradio interface
181
+ with gr.Blocks(title="Children's Book Illustrator", theme="soft") as demo:
182
+ gr.Markdown("# πŸ“š Children's Book Illustrator")
183
+ gr.Markdown("Generate beautiful storybook images and automatically save them to your OCI storage")
184
+
185
+ with gr.Tab("Single Image Generation"):
186
+ with gr.Row():
187
+ with gr.Column():
188
+ prompt_input = gr.Textbox(
189
+ label="Scene Description",
190
+ placeholder="Describe a scene for your storybook...\nExample: A dragon reading a book under a magical tree",
191
+ lines=3
192
+ )
193
+ generate_btn = gr.Button("🎨 Generate Image", variant="primary")
194
+
195
+ with gr.Column():
196
+ image_output = gr.Image(label="Generated Image", height=400)
197
+ status_output = gr.Textbox(label="Status", interactive=False)
198
+
199
+ with gr.Tab("Batch Storybook Generation"):
200
+ with gr.Row():
201
+ with gr.Column():
202
+ scenes_input = gr.Textbox(
203
+ label="Story Scenes (One per line)",
204
+ placeholder="Enter each scene on a separate line...\nExample:\nA brave knight approaches the castle\nThe dragon guards a treasure chest\nChildren celebrating with the villagers",
205
+ lines=6
206
+ )
207
+ batch_generate_btn = gr.Button("πŸ“– Generate Storybook", variant="primary")
208
+
209
+ with gr.Column():
210
+ batch_status = gr.Textbox(label="Generation Status", interactive=False, lines=10)
211
 
212
+ # Gallery for batch results
213
+ gallery_output = gr.Gallery(
214
+ label="Generated Storybook Scenes",
215
+ columns=2,
216
+ height=600
217
+ )
218
+
219
+ with gr.Tab("About & Help"):
220
+ gr.Markdown("""
221
+ ## 🎯 How to Use
222
+
223
+ 1. **Single Image**: Enter a scene description and click "Generate Image"
224
+ 2. **Storybook**: Enter multiple scenes (one per line) for a complete story
225
+ 3. **Auto-Save**: All images are automatically saved to your OCI storage
226
+
227
+ ## πŸ“ Storage Location
228
+
229
+ Images are saved to: `storybook-generator/childrens-books/`
230
+
231
+ ## πŸ’‘ Prompt Tips
232
+
233
+ - Be descriptive: "A dragon reading a book under a magical tree"
234
+ - Add style: "watercolor style, cute, whimsical"
235
+ - Specify characters: "little mouse exploring a giant forest"
236
+
237
+ ## πŸ”§ Technical Details
238
+
239
+ - Uses Stable Diffusion v1.5
240
+ - Saves to OCI Object Storage
241
+ - Automatic image organization
242
+ """)
243
+
244
+ # Connect buttons to functions
245
+ generate_btn.click(
246
+ fn=generate_storybook_image,
247
+ inputs=prompt_input,
248
+ outputs=[image_output, status_output]
249
+ )
250
+
251
+ batch_generate_btn.click(
252
+ fn=batch_generate_storybook,
253
+ inputs=scenes_input,
254
+ outputs=[gallery_output, batch_status]
255
+ )
256
+
257
+ # For Hugging Face Spaces deployment
258
+ def get_app():
259
+ """Return the Gradio app for Hugging Face"""
260
+ return demo
261
 
262
+ # For local testing
263
  if __name__ == "__main__":
264
+ print("πŸš€ Starting Children's Book Illustrator...")
265
+ print(f"πŸ“¦ OCI Access: {'Direct' if DIRECT_OCI_ACCESS else 'API'}")
266
+ demo.launch(server_name="0.0.0.0", server_port=7860)
267
+ else:
268
+ # For Hugging Face deployment
269
+ print("πŸ“¦ Hugging Face Space detected")
270
+ print(f"πŸ”§ OCI Access: {'Direct' if DIRECT_OCI_ACCESS else 'API'}")