LogicGoInfotechSpaces commited on
Commit
20fc4dd
·
verified ·
1 Parent(s): 4dc04ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -108
app.py CHANGED
@@ -14,25 +14,21 @@ import gridfs
14
  from bson.objectid import ObjectId
15
  from PIL import Image
16
  from fastapi.concurrency import run_in_threadpool
 
17
  import shutil
18
- import firebase_admin
19
- from firebase_admin import credentials, auth
20
  from PIL import Image
21
  from huggingface_hub import InferenceClient
 
 
22
 
23
  # ---------------------------------------------------------------------
24
- # Load Firebase Config from env (stringified JSON)
25
  # ---------------------------------------------------------------------
26
- firebase_config_json = os.getenv("firebase_config")
27
- if not firebase_config_json:
28
- raise RuntimeError("❌ Missing Firebase config in environment variable 'firebase_config'")
 
29
 
30
- try:
31
- firebase_creds_dict = json.loads(firebase_config_json)
32
- cred = credentials.Certificate(firebase_creds_dict)
33
- firebase_admin.initialize_app(cred)
34
- except Exception as e:
35
- raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}")
36
 
37
  # ---------------------------------------------------------------------
38
  # Hugging Face setup
@@ -43,6 +39,14 @@ if not HF_TOKEN:
43
 
44
  hf_client = InferenceClient(token=HF_TOKEN)
45
 
 
 
 
 
 
 
 
 
46
  # ---------------------------------------------------------------------
47
  # MongoDB setup
48
  # ---------------------------------------------------------------------
@@ -57,7 +61,7 @@ logs_collection = db["logs"]
57
  # ---------------------------------------------------------------------
58
  # FastAPI app setup
59
  # ---------------------------------------------------------------------
60
- app = FastAPI(title="Qwen Image Edit API with Firebase Auth")
61
  app.add_middleware(
62
  CORSMiddleware,
63
  allow_origins=["*"],
@@ -66,22 +70,7 @@ app.add_middleware(
66
  allow_headers=["*"],
67
  )
68
 
69
- # ---------------------------------------------------------------------
70
- # Auth dependency
71
- # ---------------------------------------------------------------------
72
- async def verify_firebase_token(request: Request):
73
- """Middleware-like dependency to verify Firebase JWT from Authorization header."""
74
- auth_header = request.headers.get("Authorization")
75
- if not auth_header or not auth_header.startswith("Bearer "):
76
- raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
77
 
78
- id_token = auth_header.split("Bearer ")[1]
79
- try:
80
- decoded_token = auth.verify_id_token(id_token)
81
- request.state.user = decoded_token
82
- return decoded_token
83
- except Exception as e:
84
- raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}")
85
 
86
  # ---------------------------------------------------------------------
87
  # Models
@@ -90,6 +79,7 @@ class HealthResponse(BaseModel):
90
  status: str
91
  db: str
92
  model: str
 
93
 
94
  # --------------------- UTILS ---------------------
95
  def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image:
@@ -171,16 +161,19 @@ async def root():
171
  def health():
172
  """Public health check"""
173
  mongo.admin.command("ping")
174
- return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit")
 
175
 
176
  @app.post("/generate")
177
  async def generate(
178
  prompt: str = Form(...),
 
179
  image1: UploadFile = File(...),
180
  image2: Optional[UploadFile] = File(None),
181
  user_id: Optional[str] = Form(None),
182
  category_id: Optional[str] = Form(None),
183
- user=Depends(verify_firebase_token)
 
184
  ):
185
  start_time = time.time()
186
 
@@ -235,28 +228,21 @@ async def generate(
235
  try:
236
  admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI"))
237
  admin_db = admin_client["adminPanel"]
238
-
239
  categories_col = admin_db.categories
240
  media_clicks_col = admin_db.media_clicks
241
-
242
- # Validate user_oid & category_oid
243
  user_oid = ObjectId(user_id)
244
  category_oid = ObjectId(category_id)
245
-
246
- # Check category exists
247
  category_doc = categories_col.find_one({"_id": category_oid})
248
  if not category_doc:
249
  raise HTTPException(400, f"Invalid category_id: {category_id}")
250
-
251
  now = datetime.utcnow()
252
-
253
- # Normalize dates (UTC midnight)
254
  today_date = datetime(now.year, now.month, now.day)
255
- yesterday_date = today_date - timedelta(days=1)
256
-
257
- # --------------------------------------------------
258
- # AI EDIT USAGE TRACKING (GLOBAL PER USER)
259
- # --------------------------------------------------
260
  media_clicks_col.update_one(
261
  {"userId": user_oid},
262
  {
@@ -274,49 +260,27 @@ async def generate(
274
  },
275
  upsert=True
276
  )
277
-
278
- # --------------------------------------------------
279
- # DAILY COUNT LOGIC
280
- # --------------------------------------------------
281
- now = datetime.utcnow()
282
- today_date = datetime(now.year, now.month, now.day)
283
-
284
- doc = media_clicks_col.find_one(
285
- {"userId": user_oid},
286
- {"ai_edit_daily_count": 1}
287
- )
288
-
289
  daily_entries = doc.get("ai_edit_daily_count", []) if doc else []
290
-
291
- # Build UNIQUE date -> count map
292
  daily_map = {}
293
  for entry in daily_entries:
294
  d = entry["date"]
295
  d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d
296
- daily_map[d] = entry["count"] # overwrite = no duplicates
297
-
298
- # Find last known date
299
  last_date = max(daily_map.keys()) if daily_map else today_date
300
-
301
- # Fill ALL missing days with 0
302
  next_day = last_date + timedelta(days=1)
303
  while next_day < today_date:
304
  daily_map.setdefault(next_day, 0)
305
  next_day += timedelta(days=1)
306
-
307
- # Mark today as used (binary)
308
  daily_map[today_date] = 1
309
-
310
- # Rebuild list (OLD NEW)
311
- final_daily_entries = [
312
- {"date": d, "count": daily_map[d]}
313
- for d in sorted(daily_map.keys())
314
- ]
315
-
316
- # Keep last 32 days only
317
  final_daily_entries = final_daily_entries[-32:]
318
-
319
- # ATOMIC REPLACE (NO PUSH)
320
  media_clicks_col.update_one(
321
  {"userId": user_oid},
322
  {
@@ -328,10 +292,6 @@ async def generate(
328
  }
329
  )
330
 
331
-
332
- # --------------------------------------------------
333
- # CATEGORY CLICK LOGIC
334
- # --------------------------------------------------
335
  update_res = media_clicks_col.update_one(
336
  {"userId": user_oid, "categories.categoryId": category_oid},
337
  {
@@ -339,13 +299,10 @@ async def generate(
339
  "updatedAt": now,
340
  "categories.$.lastClickedAt": now
341
  },
342
- "$inc": {
343
- "categories.$.click_count": 1
344
- }
345
  }
346
  )
347
-
348
- # If category does not exist → push new
349
  if update_res.matched_count == 0:
350
  media_clicks_col.update_one(
351
  {"userId": user_oid},
@@ -361,18 +318,55 @@ async def generate(
361
  },
362
  upsert=True
363
  )
364
-
365
  except Exception as e:
366
  print("CATEGORY_LOG_ERROR:", e)
 
367
  # -------------------------
368
- # 4. HF INFERENCE
369
  # -------------------------
 
 
 
 
 
 
 
370
  try:
371
- pil_output = hf_client.image_to_image(
372
- image=combined_img,
373
- prompt=prompt,
374
- model="Qwen/Qwen-Image-Edit"
375
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  except Exception as e:
377
  response_time_ms = round((time.time() - start_time) * 1000)
378
  logs_collection.insert_one({
@@ -381,11 +375,11 @@ async def generate(
381
  "input1_id": str(input1_id),
382
  "input2_id": str(input2_id) if input2_id else None,
383
  "prompt": prompt,
384
- "user_email": user.get("email"),
385
  "error": str(e),
386
  "response_time_ms": response_time_ms
387
  })
388
- raise HTTPException(500, f"Inference failed: {e}")
389
 
390
  # -------------------------
391
  # 5. SAVE OUTPUT IMAGE
@@ -403,17 +397,13 @@ async def generate(
403
  "prompt": prompt,
404
  "input1_id": str(input1_id),
405
  "input2_id": str(input2_id) if input2_id else None,
406
- "user_email": user.get("email"),
 
407
  }
408
  )
409
- # -------------------------
410
- # 5b. SAVE COMPRESSED IMAGE
411
- # -------------------------
412
- compressed_bytes = compress_pil_image_to_2mb(
413
- pil_output,
414
- max_dim=1280
415
- )
416
-
417
  compressed_id = fs.put(
418
  compressed_bytes,
419
  filename=f"result_{input1_id}_compressed.jpg",
@@ -422,11 +412,11 @@ async def generate(
422
  "role": "output_compressed",
423
  "original_output_id": str(out_id),
424
  "prompt": prompt,
425
- "user_email": user.get("email")
 
426
  }
427
  )
428
 
429
-
430
  response_time_ms = round((time.time() - start_time) * 1000)
431
 
432
  # -------------------------
@@ -439,18 +429,18 @@ async def generate(
439
  "input2_id": str(input2_id) if input2_id else None,
440
  "output_id": str(out_id),
441
  "prompt": prompt,
442
- "user_email": user.get("email"),
 
443
  "response_time_ms": response_time_ms
444
  })
445
 
446
  return JSONResponse({
447
  "output_image_id": str(out_id),
448
- "user": user.get("email"),
449
  "response_time_ms": response_time_ms,
450
- "Compressed_Image_URL": (
451
- f"https://logicgoinfotechspaces-polaroidimage.hf.space/image/{compressed_id}"
452
- )
453
  })
 
454
  ####################---------------------------------------------------------------------------------------------------###
455
  ###-----OLD CODE--------------####
456
  # @app.post("/generate")
 
14
  from bson.objectid import ObjectId
15
  from PIL import Image
16
  from fastapi.concurrency import run_in_threadpool
17
+ import logging
18
  import shutil
 
 
19
  from PIL import Image
20
  from huggingface_hub import InferenceClient
21
+ from dotenv import load_dotenv
22
+ import google.generativeai as genai
23
 
24
  # ---------------------------------------------------------------------
25
+ # Load Firebase Config from env (stringified JSON) or from a credentials file
26
  # ---------------------------------------------------------------------
27
+ # Load .env file if present (allows using a .env in the project root)
28
+ load_dotenv()
29
+
30
+ # Firebase authentication removed — no initialization required
31
 
 
 
 
 
 
 
32
 
33
  # ---------------------------------------------------------------------
34
  # Hugging Face setup
 
39
 
40
  hf_client = InferenceClient(token=HF_TOKEN)
41
 
42
+ # Model configuration: defaults can be overridden via environment variables
43
+ QWEN_MODEL = os.getenv("QWEN_MODEL", "Qwen/Qwen-Image-Edit")
44
+ GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-image")
45
+ MODEL_MAP = {
46
+ "qwen": QWEN_MODEL,
47
+ "gemini": GEMINI_MODEL
48
+ }
49
+
50
  # ---------------------------------------------------------------------
51
  # MongoDB setup
52
  # ---------------------------------------------------------------------
 
61
  # ---------------------------------------------------------------------
62
  # FastAPI app setup
63
  # ---------------------------------------------------------------------
64
+ app = FastAPI(title="Qwen Image Edit API")
65
  app.add_middleware(
66
  CORSMiddleware,
67
  allow_origins=["*"],
 
70
  allow_headers=["*"],
71
  )
72
 
 
 
 
 
 
 
 
 
73
 
 
 
 
 
 
 
 
74
 
75
  # ---------------------------------------------------------------------
76
  # Models
 
79
  status: str
80
  db: str
81
  model: str
82
+ available_models: dict
83
 
84
  # --------------------- UTILS ---------------------
85
  def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image:
 
161
  def health():
162
  """Public health check"""
163
  mongo.admin.command("ping")
164
+ # Return default model (qwen) and available models mapping
165
+ return HealthResponse(status="ok", db=db.name, model=MODEL_MAP.get("qwen"), available_models=MODEL_MAP)
166
 
167
  @app.post("/generate")
168
  async def generate(
169
  prompt: str = Form(...),
170
+ model: str = Form("qwen"),
171
  image1: UploadFile = File(...),
172
  image2: Optional[UploadFile] = File(None),
173
  user_id: Optional[str] = Form(None),
174
  category_id: Optional[str] = Form(None),
175
+ # Firebase auth disabled — `user` can be provided by client if needed, otherwise it will be None
176
+ user: Optional[dict] = None
177
  ):
178
  start_time = time.time()
179
 
 
228
  try:
229
  admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI"))
230
  admin_db = admin_client["adminPanel"]
231
+
232
  categories_col = admin_db.categories
233
  media_clicks_col = admin_db.media_clicks
234
+
 
235
  user_oid = ObjectId(user_id)
236
  category_oid = ObjectId(category_id)
237
+
 
238
  category_doc = categories_col.find_one({"_id": category_oid})
239
  if not category_doc:
240
  raise HTTPException(400, f"Invalid category_id: {category_id}")
241
+
242
  now = datetime.utcnow()
 
 
243
  today_date = datetime(now.year, now.month, now.day)
244
+
245
+ # Update daily AI edit usage
 
 
 
246
  media_clicks_col.update_one(
247
  {"userId": user_oid},
248
  {
 
260
  },
261
  upsert=True
262
  )
263
+
264
+ # Update category click counts
265
+ doc = media_clicks_col.find_one({"userId": user_oid}, {"ai_edit_daily_count": 1})
 
 
 
 
 
 
 
 
 
266
  daily_entries = doc.get("ai_edit_daily_count", []) if doc else []
267
+
 
268
  daily_map = {}
269
  for entry in daily_entries:
270
  d = entry["date"]
271
  d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d
272
+ daily_map[d] = entry["count"]
273
+
 
274
  last_date = max(daily_map.keys()) if daily_map else today_date
 
 
275
  next_day = last_date + timedelta(days=1)
276
  while next_day < today_date:
277
  daily_map.setdefault(next_day, 0)
278
  next_day += timedelta(days=1)
 
 
279
  daily_map[today_date] = 1
280
+
281
+ final_daily_entries = [{"date": d, "count": daily_map[d]} for d in sorted(daily_map.keys())]
 
 
 
 
 
 
282
  final_daily_entries = final_daily_entries[-32:]
283
+
 
284
  media_clicks_col.update_one(
285
  {"userId": user_oid},
286
  {
 
292
  }
293
  )
294
 
 
 
 
 
295
  update_res = media_clicks_col.update_one(
296
  {"userId": user_oid, "categories.categoryId": category_oid},
297
  {
 
299
  "updatedAt": now,
300
  "categories.$.lastClickedAt": now
301
  },
302
+ "$inc": {"categories.$.click_count": 1}
 
 
303
  }
304
  )
305
+
 
306
  if update_res.matched_count == 0:
307
  media_clicks_col.update_one(
308
  {"userId": user_oid},
 
318
  },
319
  upsert=True
320
  )
321
+
322
  except Exception as e:
323
  print("CATEGORY_LOG_ERROR:", e)
324
+
325
  # -------------------------
326
+ # 4. INFERENCE (Gemini direct + HF fallback)
327
  # -------------------------
328
+ used_model = None
329
+ pil_output = None
330
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
331
+
332
+ # determine user email safely (Firebase auth is disabled)
333
+ user_email = user.get("email") if user else None
334
+
335
  try:
336
+ chosen_key = model.lower()
337
+ chosen_model = MODEL_MAP.get(chosen_key, model)
338
+ fallback_model = MODEL_MAP.get("qwen") if chosen_key == "gemini" else MODEL_MAP.get("gemini")
339
+ used_model = chosen_model
340
+
341
+ try:
342
+ pil_output = hf_client.image_to_image(
343
+ image=combined_img,
344
+ prompt=prompt,
345
+ model=chosen_model
346
+ )
347
+ except Exception as e_primary:
348
+ logging.warning(f"Primary model {chosen_model} failed: {e_primary}. Trying fallback {fallback_model}")
349
+ try:
350
+ pil_output = hf_client.image_to_image(
351
+ image=combined_img,
352
+ prompt=prompt,
353
+ model=fallback_model
354
+ )
355
+ used_model = fallback_model
356
+ except Exception as e_fallback:
357
+ response_time_ms = round((time.time() - start_time) * 1000)
358
+ logs_collection.insert_one({
359
+ "timestamp": datetime.utcnow(),
360
+ "status": "failure",
361
+ "input1_id": str(input1_id),
362
+ "input2_id": str(input2_id) if input2_id else None,
363
+ "prompt": prompt,
364
+ "user_email": user_email,
365
+ "error": f"Primary error: {e_primary}; Fallback error: {e_fallback}",
366
+ "response_time_ms": response_time_ms
367
+ })
368
+ raise HTTPException(500, detail=f"Inference failed: primary error: {e_primary}; fallback error: {e_fallback}")
369
+
370
  except Exception as e:
371
  response_time_ms = round((time.time() - start_time) * 1000)
372
  logs_collection.insert_one({
 
375
  "input1_id": str(input1_id),
376
  "input2_id": str(input2_id) if input2_id else None,
377
  "prompt": prompt,
378
+ "user_email": user_email,
379
  "error": str(e),
380
  "response_time_ms": response_time_ms
381
  })
382
+ raise HTTPException(500, detail=f"Inference failed: {e}")
383
 
384
  # -------------------------
385
  # 5. SAVE OUTPUT IMAGE
 
397
  "prompt": prompt,
398
  "input1_id": str(input1_id),
399
  "input2_id": str(input2_id) if input2_id else None,
400
+ "user_email": user_email,
401
+ "used_model": used_model,
402
  }
403
  )
404
+
405
+ compressed_bytes = compress_pil_image_to_2mb(pil_output, max_dim=1280)
406
+
 
 
 
 
 
407
  compressed_id = fs.put(
408
  compressed_bytes,
409
  filename=f"result_{input1_id}_compressed.jpg",
 
412
  "role": "output_compressed",
413
  "original_output_id": str(out_id),
414
  "prompt": prompt,
415
+ "used_model": used_model,
416
+ "user_email": user_email
417
  }
418
  )
419
 
 
420
  response_time_ms = round((time.time() - start_time) * 1000)
421
 
422
  # -------------------------
 
429
  "input2_id": str(input2_id) if input2_id else None,
430
  "output_id": str(out_id),
431
  "prompt": prompt,
432
+ "user_email": user_email,
433
+ "used_model": used_model,
434
  "response_time_ms": response_time_ms
435
  })
436
 
437
  return JSONResponse({
438
  "output_image_id": str(out_id),
439
+ "user": user_email,
440
  "response_time_ms": response_time_ms,
441
+ "Compressed_Image_URL": f"https://logicgoinfotechspaces-polaroidimage.hf.space/image/{compressed_id}"
 
 
442
  })
443
+
444
  ####################---------------------------------------------------------------------------------------------------###
445
  ###-----OLD CODE--------------####
446
  # @app.post("/generate")