yukee1992 commited on
Commit
dd35031
·
verified ·
1 Parent(s): 606ff17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -2
app.py CHANGED
@@ -8,6 +8,14 @@ import io
8
  import base64
9
  from PIL import Image
10
  import time
 
 
 
 
 
 
 
 
11
 
12
  # Initialize FastAPI
13
  app = FastAPI(title="Children's Book Illustrator API")
@@ -53,17 +61,87 @@ except Exception as e:
53
  pipe = pipe.to(device)
54
  print(f"Fell back to {model_id}")
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Request model
57
  class GenerateRequest(BaseModel):
58
  prompt: str
59
  width: int = 512
60
  height: int = 512
61
  steps: int = 25
 
62
 
63
  # Health check endpoint
64
  @app.get("/")
65
  async def health_check():
66
- return {"status": "healthy", "model": model_id}
 
67
 
68
  # Main API endpoint
69
  @app.post("/generate")
@@ -91,6 +169,11 @@ async def generate_image(request: GenerateRequest):
91
 
92
  print("Image generated successfully!")
93
 
 
 
 
 
 
94
  # Convert to base64 for API response
95
  buffered = io.BytesIO()
96
  image.save(buffered, format="PNG")
@@ -99,7 +182,9 @@ async def generate_image(request: GenerateRequest):
99
  return {
100
  "status": "success",
101
  "image": f"data:image/png;base64,{img_base64}",
102
- "prompt": request.prompt
 
 
103
  }
104
 
105
  except Exception as e:
 
8
  import base64
9
  from PIL import Image
10
  import time
11
+ from datetime import datetime
12
+ import os
13
+
14
+ # Google Drive imports
15
+ from google.oauth2 import service_account
16
+ from googleapiclient.discovery import build
17
+ from googleapiclient.http import MediaIoBaseUpload
18
+ import json
19
 
20
  # Initialize FastAPI
21
  app = FastAPI(title="Children's Book Illustrator API")
 
61
  pipe = pipe.to(device)
62
  print(f"Fell back to {model_id}")
63
 
64
+ # Google Drive Setup
65
+ def setup_google_drive():
66
+ """Initialize Google Drive service"""
67
+ try:
68
+ # Get service account credentials from environment variable
69
+ credentials_json = os.getenv('GOOGLE_SERVICE_ACCOUNT_JSON')
70
+ if not credentials_json:
71
+ print("Google Drive: No service account credentials found")
72
+ return None
73
+
74
+ # Parse the JSON credentials
75
+ service_account_info = json.loads(credentials_json)
76
+ credentials = service_account.Credentials.from_service_account_info(
77
+ service_account_info,
78
+ scopes=['https://www.googleapis.com/auth/drive.file']
79
+ )
80
+
81
+ # Build the Drive service
82
+ drive_service = build('drive', 'v3', credentials=credentials)
83
+ print("Google Drive service initialized successfully")
84
+ return drive_service
85
+
86
+ except Exception as e:
87
+ print(f"Google Drive setup failed: {e}")
88
+ return None
89
+
90
+ # Initialize Google Drive service
91
+ drive_service = setup_google_drive()
92
+
93
+ def save_to_google_drive(image, prompt):
94
+ """Save image to Google Drive"""
95
+ if not drive_service:
96
+ print("Google Drive not configured, skipping save")
97
+ return None
98
+
99
+ try:
100
+ # Create a filename with timestamp
101
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
102
+ safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip()
103
+ filename = f"storybook_{timestamp}_{safe_prompt}.png"
104
+
105
+ # Convert image to bytes
106
+ img_bytes = io.BytesIO()
107
+ image.save(img_bytes, format='PNG')
108
+ img_bytes.seek(0)
109
+
110
+ # Create file metadata
111
+ file_metadata = {
112
+ 'name': filename,
113
+ 'mimeType': 'image/png',
114
+ 'parents': ['root'] # Save to root folder, or specify folder ID
115
+ }
116
+
117
+ # Upload to Google Drive
118
+ media = MediaIoBaseUpload(img_bytes, mimetype='image/png', resumable=True)
119
+ file = drive_service.files().create(
120
+ body=file_metadata,
121
+ media_body=media,
122
+ fields='id, webViewLink'
123
+ ).execute()
124
+
125
+ print(f"Image saved to Google Drive: {file.get('webViewLink')}")
126
+ return file.get('webViewLink')
127
+
128
+ except Exception as e:
129
+ print(f"Failed to save to Google Drive: {e}")
130
+ return None
131
+
132
  # Request model
133
  class GenerateRequest(BaseModel):
134
  prompt: str
135
  width: int = 512
136
  height: int = 512
137
  steps: int = 25
138
+ save_to_drive: bool = True # New option to control saving
139
 
140
  # Health check endpoint
141
  @app.get("/")
142
  async def health_check():
143
+ drive_status = "connected" if drive_service else "disconnected"
144
+ return {"status": "healthy", "model": model_id, "google_drive": drive_status}
145
 
146
  # Main API endpoint
147
  @app.post("/generate")
 
169
 
170
  print("Image generated successfully!")
171
 
172
+ # Save to Google Drive if enabled
173
+ drive_link = None
174
+ if request.save_to_drive and drive_service:
175
+ drive_link = save_to_google_drive(image, request.prompt)
176
+
177
  # Convert to base64 for API response
178
  buffered = io.BytesIO()
179
  image.save(buffered, format="PNG")
 
182
  return {
183
  "status": "success",
184
  "image": f"data:image/png;base64,{img_base64}",
185
+ "prompt": request.prompt,
186
+ "google_drive_link": drive_link,
187
+ "saved_to_drive": drive_link is not None
188
  }
189
 
190
  except Exception as e: