AkashKumarave commited on
Commit
9a71405
·
verified ·
1 Parent(s): 01117dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -70
app.py CHANGED
@@ -8,20 +8,13 @@ import os
8
  import time
9
  import jwt
10
  from pathlib import Path
11
- from typing import List, Optional
12
- from pydantic import BaseModel
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
- # ===== API CONFIGURATION =====
19
- # Load sensitive data from environment variables
20
- ACCESS_KEY_ID = os.getenv("ACCESS_KEY_ID", "AFyHfnQATghFdCMyAG3gRPbNY4TNKFGB")
21
- ACCESS_KEY_SECRET = os.getenv("ACCESS_KEY_SECRET", "TTepeLyBterLNM3brYPGmdndBnnyKJBA")
22
- API_BASE_URL = "https://api-singapore.klingai.com"
23
- CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"
24
-
25
  # Initialize FastAPI app
26
  app = FastAPI(title="Kling AI Multi-Image Generator API")
27
 
@@ -34,6 +27,12 @@ app.add_middleware(
34
  allow_headers=["*"],
35
  )
36
 
 
 
 
 
 
 
37
  # ===== AUTHENTICATION =====
38
  def generate_jwt_token():
39
  """Generate JWT token for API authentication"""
@@ -51,19 +50,17 @@ def prepare_image_base64(image_content: bytes):
51
  return base64.b64encode(image_content).decode('utf-8')
52
  except Exception as e:
53
  logger.error(f"Image processing failed: {str(e)}")
54
- return None
55
 
56
  def validate_image(image_content: bytes):
57
  """Validate image meets API requirements"""
58
  try:
59
- # Check file size
60
  size_mb = len(image_content) / (1024 * 1024)
61
  if size_mb > 10:
62
- return False, "Image too large (max 10MB)"
63
- # Basic validation (add PIL-based dimension checks if needed)
64
  return True, ""
65
  except Exception as e:
66
- return False, f"Image validation error: {str(e)}"
67
 
68
  # ===== API FUNCTIONS =====
69
  def create_multi_image_task(subject_images: List[bytes], prompt: str):
@@ -73,7 +70,6 @@ def create_multi_image_task(subject_images: List[bytes], prompt: str):
73
  "Content-Type": "application/json"
74
  }
75
 
76
- # Prepare subject images list
77
  subject_image_list = []
78
  for img_content in subject_images:
79
  if img_content:
@@ -82,7 +78,7 @@ def create_multi_image_task(subject_images: List[bytes], prompt: str):
82
  subject_image_list.append({"subject_image": base64_img})
83
 
84
  if len(subject_image_list) < 2:
85
- return None, "At least 2 subject images required"
86
 
87
  payload = {
88
  "model_name": "kling-v2",
@@ -95,12 +91,12 @@ def create_multi_image_task(subject_images: List[bytes], prompt: str):
95
  try:
96
  response = requests.post(CREATE_TASK_ENDPOINT, json=payload, headers=headers)
97
  response.raise_for_status()
98
- return response.json(), None
99
  except requests.exceptions.RequestException as e:
100
  logger.error(f"API request failed: {str(e)}")
101
  if hasattr(e, 'response') and e.response:
102
  logger.error(f"API response: {e.response.text}")
103
- return None, f"API Error: {str(e)}"
104
 
105
  def check_task_status(task_id: str):
106
  """Check task completion status"""
@@ -110,37 +106,26 @@ def check_task_status(task_id: str):
110
  try:
111
  response = requests.get(status_url, headers=headers)
112
  response.raise_for_status()
113
- return response.json(), None
114
  except requests.exceptions.RequestException as e:
115
- return None, f"Status check failed: {str(e)}"
116
 
117
  # ===== MAIN PROCESSING =====
118
  async def generate_image(subject_images: List[bytes], prompt: str):
119
  """Handle complete image generation workflow"""
120
- # Validate images
121
  for img_content in subject_images:
122
  if img_content:
123
- is_valid, error_msg = validate_image(img_content)
124
- if not is_valid:
125
- return None, error_msg
126
-
127
- # Create task
128
- task_response, error = create_multi_image_task(subject_images, prompt)
129
- if error:
130
- return None, error
131
 
 
132
  if task_response.get("code") != 0:
133
- return None, f"API error: {task_response.get('message', 'Unknown error')}"
134
 
135
  task_id = task_response["data"]["task_id"]
136
  logger.info(f"Task created: {task_id}")
137
 
138
- # Poll for results (max 10 minutes)
139
  for _ in range(60):
140
- task_data, error = check_task_status(task_id)
141
- if error:
142
- return None, error
143
-
144
  status = task_data["data"]["task_status"]
145
 
146
  if status == "succeed":
@@ -148,62 +133,52 @@ async def generate_image(subject_images: List[bytes], prompt: str):
148
  try:
149
  response = requests.get(image_url)
150
  response.raise_for_status()
151
- output_dir = Path("/tmp") # Hugging Face Spaces uses /tmp for ephemeral storage
152
  output_dir.mkdir(exist_ok=True)
153
  output_path = output_dir / f"kling_output_{task_id}.png"
154
  with open(output_path, "wb") as f:
155
  f.write(response.content)
156
- return str(output_path), None
157
  except Exception as e:
158
- return None, f"Failed to download result: {str(e)}"
159
 
160
  elif status in ("failed", "canceled"):
161
  error_msg = task_data["data"].get("task_status_msg", "Unknown error")
162
- return None, f"Task failed: {error_msg}"
163
 
164
  time.sleep(10)
165
 
166
- return None, "Task timed out after 10 minutes"
167
 
168
  # ===== API ENDPOINTS =====
169
- class GenerateImageRequest(BaseModel):
170
- prompt: str
171
-
172
  @app.post("/generate")
173
  async def generate_image_endpoint(
174
  prompt: str = Form(...),
175
  images: List[UploadFile] = File(...)
176
  ):
177
  """Endpoint to generate an image from multiple input images and a prompt"""
178
- if len(images) < 2:
179
- raise HTTPException(status_code=400, detail="At least 2 images are required")
180
- if len(images) > 4:
181
- raise HTTPException(status_code=400, detail="Maximum 4 images allowed")
182
-
183
- # Read image contents
184
- image_contents = []
185
- for image in images:
186
- content = await image.read()
187
- image_contents.append(content)
188
-
189
- # Generate image
190
- output_path, error = await generate_image(image_contents, prompt)
191
- if error:
192
- raise HTTPException(status_code=500, detail=error)
193
-
194
- # Return the generated image file
195
- return FileResponse(
196
- path=output_path,
197
- media_type="image/png",
198
- filename=f"kling_output_{Path(output_path).stem}.png"
199
- )
200
 
201
- # Health check endpoint
202
- @app.get("/health")
203
- async def health_check():
204
- return {"status": "healthy"}
205
 
206
  if __name__ == "__main__":
207
  import uvicorn
208
- # Hugging Face Spaces expects port 7860
209
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
8
  import time
9
  import jwt
10
  from pathlib import Path
11
+ from typing import List
12
+ import io
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
 
 
 
18
  # Initialize FastAPI app
19
  app = FastAPI(title="Kling AI Multi-Image Generator API")
20
 
 
27
  allow_headers=["*"],
28
  )
29
 
30
+ # ===== API CONFIGURATION =====
31
+ ACCESS_KEY_ID = os.getenv("ACCESS_KEY_ID", "AFyHfnQATghFdCMyAG3gRPbNY4TNKFGB")
32
+ ACCESS_KEY_SECRET = os.getenv("ACCESS_KEY_SECRET", "TTepeLyBterLNM3brYPGmdndBnnyKJBA")
33
+ API_BASE_URL = "https://api-singapore.klingai.com"
34
+ CREATE_TASK_ENDPOINT = f"{API_BASE_URL}/v1/images/multi-image2image"
35
+
36
  # ===== AUTHENTICATION =====
37
  def generate_jwt_token():
38
  """Generate JWT token for API authentication"""
 
50
  return base64.b64encode(image_content).decode('utf-8')
51
  except Exception as e:
52
  logger.error(f"Image processing failed: {str(e)}")
53
+ raise HTTPException(status_code=500, detail=f"Image processing failed: {str(e)}")
54
 
55
  def validate_image(image_content: bytes):
56
  """Validate image meets API requirements"""
57
  try:
 
58
  size_mb = len(image_content) / (1024 * 1024)
59
  if size_mb > 10:
60
+ raise HTTPException(status_code=400, detail="Image too large (max 10MB)")
 
61
  return True, ""
62
  except Exception as e:
63
+ raise HTTPException(status_code=400, detail=f"Image validation error: {str(e)}")
64
 
65
  # ===== API FUNCTIONS =====
66
  def create_multi_image_task(subject_images: List[bytes], prompt: str):
 
70
  "Content-Type": "application/json"
71
  }
72
 
 
73
  subject_image_list = []
74
  for img_content in subject_images:
75
  if img_content:
 
78
  subject_image_list.append({"subject_image": base64_img})
79
 
80
  if len(subject_image_list) < 2:
81
+ raise HTTPException(status_code=400, detail="At least 2 subject images required")
82
 
83
  payload = {
84
  "model_name": "kling-v2",
 
91
  try:
92
  response = requests.post(CREATE_TASK_ENDPOINT, json=payload, headers=headers)
93
  response.raise_for_status()
94
+ return response.json()
95
  except requests.exceptions.RequestException as e:
96
  logger.error(f"API request failed: {str(e)}")
97
  if hasattr(e, 'response') and e.response:
98
  logger.error(f"API response: {e.response.text}")
99
+ raise HTTPException(status_code=500, detail=f"API Error: {str(e)}")
100
 
101
  def check_task_status(task_id: str):
102
  """Check task completion status"""
 
106
  try:
107
  response = requests.get(status_url, headers=headers)
108
  response.raise_for_status()
109
+ return response.json()
110
  except requests.exceptions.RequestException as e:
111
+ raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
112
 
113
  # ===== MAIN PROCESSING =====
114
  async def generate_image(subject_images: List[bytes], prompt: str):
115
  """Handle complete image generation workflow"""
 
116
  for img_content in subject_images:
117
  if img_content:
118
+ validate_image(img_content)
 
 
 
 
 
 
 
119
 
120
+ task_response = create_multi_image_task(subject_images, prompt)
121
  if task_response.get("code") != 0:
122
+ raise HTTPException(status_code=500, detail=f"API error: {task_response.get('message', 'Unknown error')}")
123
 
124
  task_id = task_response["data"]["task_id"]
125
  logger.info(f"Task created: {task_id}")
126
 
 
127
  for _ in range(60):
128
+ task_data = check_task_status(task_id)
 
 
 
129
  status = task_data["data"]["task_status"]
130
 
131
  if status == "succeed":
 
133
  try:
134
  response = requests.get(image_url)
135
  response.raise_for_status()
136
+ output_dir = Path("/tmp")
137
  output_dir.mkdir(exist_ok=True)
138
  output_path = output_dir / f"kling_output_{task_id}.png"
139
  with open(output_path, "wb") as f:
140
  f.write(response.content)
141
+ return output_path
142
  except Exception as e:
143
+ raise HTTPException(status_code=500, detail=f"Failed to download result: {str(e)}")
144
 
145
  elif status in ("failed", "canceled"):
146
  error_msg = task_data["data"].get("task_status_msg", "Unknown error")
147
+ raise HTTPException(status_code=500, detail=f"Task failed: {error_msg}")
148
 
149
  time.sleep(10)
150
 
151
+ raise HTTPException(status_code=500, detail="Task timed out after 10 minutes")
152
 
153
  # ===== API ENDPOINTS =====
 
 
 
154
  @app.post("/generate")
155
  async def generate_image_endpoint(
156
  prompt: str = Form(...),
157
  images: List[UploadFile] = File(...)
158
  ):
159
  """Endpoint to generate an image from multiple input images and a prompt"""
160
+ try:
161
+ if len(images) < 2:
162
+ raise HTTPException(status_code=400, detail="At least 2 images are required")
163
+ if len(images) > 4:
164
+ raise HTTPException(status_code=400, detail="Maximum 4 images allowed")
165
+
166
+ image_contents = [await image.read() for image in images]
167
+ output_path = await generate_image(image_contents, prompt)
168
+
169
+ return FileResponse(
170
+ path=output_path,
171
+ media_type="image/png",
172
+ filename=f"kling_output_{Path(output_path).stem}.png"
173
+ )
174
+ except Exception as e:
175
+ logger.error(f"Error in /generate: {str(e)}")
176
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
177
 
178
+ @app.get("/")
179
+ async def index():
180
+ return {"status": "Kling AI Multi-Image Generator API is running"}
 
181
 
182
  if __name__ == "__main__":
183
  import uvicorn
 
184
  uvicorn.run(app, host="0.0.0.0", port=7860)