LogicGoInfotechSpaces commited on
Commit
73008db
·
verified ·
1 Parent(s): 530b1d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -44
app.py CHANGED
@@ -11,16 +11,19 @@ from pydantic import BaseModel
11
  from pymongo import MongoClient
12
  import gridfs
13
  from bson.objectid import ObjectId
14
-
15
- import gradio as gr
16
- import transformers_gradio
17
- import spaces
18
- import torch
19
  from PIL import Image
 
 
 
 
20
 
21
  # ---------------------------------------------------------------------
22
- # MongoDB setup
23
  # ---------------------------------------------------------------------
 
 
 
 
24
  MONGODB_URI = "mongodb+srv://harilogicgo_db_user:jFhyDM4oA4dklUsp@api-logs.i7rqf9p.mongodb.net/?appName=API-LOGS"
25
  DB_NAME = "polaroid_db"
26
 
@@ -29,8 +32,11 @@ db = mongo[DB_NAME]
29
  fs = gridfs.GridFS(db)
30
  logs_collection = db["logs"]
31
 
 
 
 
32
  # ---------------------------------------------------------------------
33
- # FastAPI app setup
34
  # ---------------------------------------------------------------------
35
  app = FastAPI(title="Qwen Image Edit API")
36
  app.add_middleware(
@@ -47,7 +53,6 @@ app.add_middleware(
47
  BEARER_TOKEN = "logicgo@123"
48
 
49
  def verify_token(authorization: str = Header(None)):
50
- """Bearer token verification."""
51
  if not authorization or not authorization.startswith("Bearer "):
52
  raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
53
  token = authorization.split(" ")[1]
@@ -55,14 +60,6 @@ def verify_token(authorization: str = Header(None)):
55
  raise HTTPException(status_code=403, detail="Invalid bearer token")
56
  return True
57
 
58
- # ---------------------------------------------------------------------
59
- # Load Qwen Image Edit model through Gradio backend
60
- # ---------------------------------------------------------------------
61
- demo = gr.load(name="Qwen/Qwen-Image-Edit", src=transformers_gradio.registry)
62
- demo.fn = spaces.GPU()(demo.fn) # GPU acceleration
63
- for fn in demo.fns.values():
64
- fn.api_name = False # disable Gradio API names
65
-
66
  # ---------------------------------------------------------------------
67
  # Models
68
  # ---------------------------------------------------------------------
@@ -76,7 +73,6 @@ class HealthResponse(BaseModel):
76
  # ---------------------------------------------------------------------
77
  @app.get("/health", response_model=HealthResponse)
78
  def health():
79
- """Health check endpoint"""
80
  try:
81
  mongo.admin.command("ping")
82
  return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit")
@@ -90,24 +86,16 @@ async def generate(
90
  image2: Optional[UploadFile] = File(None),
91
  authorized: bool = Depends(verify_token)
92
  ):
93
- """
94
- Upload 1 or 2 images + prompt, generate edited image, store both input/output in GridFS.
95
- """
96
- # -----------------------------
97
- # 1. Read first image
98
- # -----------------------------
99
  try:
100
  img1_bytes = await image1.read()
101
- if not img1_bytes:
102
- raise HTTPException(status_code=400, detail="First image is empty")
103
  pil_img1 = Image.open(io.BytesIO(img1_bytes)).convert("RGB")
104
  except Exception as e:
105
  raise HTTPException(status_code=400, detail=f"Failed to read first image: {e}")
106
 
107
- # -----------------------------
108
- # 2. Read second image if provided
109
- # -----------------------------
110
  pil_img2 = None
 
111
  if image2:
112
  try:
113
  img2_bytes = await image2.read()
@@ -116,9 +104,7 @@ async def generate(
116
  except Exception as e:
117
  raise HTTPException(status_code=400, detail=f"Failed to read second image: {e}")
118
 
119
- # -----------------------------
120
- # 3. Save input images to GridFS
121
- # -----------------------------
122
  try:
123
  input1_id = fs.put(img1_bytes, filename=image1.filename, contentType=image1.content_type, role="input")
124
  input2_id = None
@@ -127,18 +113,25 @@ async def generate(
127
  except Exception as e:
128
  raise HTTPException(status_code=500, detail=f"Failed saving input images: {e}")
129
 
130
- # -----------------------------
131
- # 4. Run Qwen Image Edit model
132
- # -----------------------------
133
  try:
134
  images_to_pass = [pil_img1]
135
  if pil_img2:
136
  images_to_pass.append(pil_img2)
137
 
138
- # Run Gradio model fn directly
139
- output_image = demo.fn(images_to_pass, prompt)
 
 
 
 
 
 
 
 
 
 
140
 
141
- # Convert output PIL to bytes
142
  out_buf = io.BytesIO()
143
  output_image.save(out_buf, format="PNG")
144
  out_bytes = out_buf.getvalue()
@@ -146,17 +139,21 @@ async def generate(
146
  traceback.print_exc()
147
  raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
148
 
149
- # -----------------------------
150
- # 5. Save output image
151
- # -----------------------------
152
  try:
153
- out_id = fs.put(out_bytes, filename=f"result_{input1_id}.png", contentType="image/png", prompt=prompt, role="output", input1_id=str(input1_id), input2_id=str(input2_id) if input2_id else None)
 
 
 
 
 
 
 
 
154
  except Exception as e:
155
  raise HTTPException(status_code=500, detail=f"Failed saving output image: {e}")
156
 
157
- # -----------------------------
158
- # 6. Log the request
159
- # -----------------------------
160
  try:
161
  logs_collection.insert_one({
162
  "timestamp": datetime.utcnow(),
@@ -170,6 +167,7 @@ async def generate(
170
 
171
  return JSONResponse({"output_image_id": str(out_id)})
172
 
 
173
  @app.get("/image/{image_id}")
174
  def get_image(image_id: str, download: Optional[bool] = False):
175
  """Retrieve image by GridFS ID"""
@@ -187,3 +185,10 @@ def get_image(image_id: str, download: Optional[bool] = False):
187
  headers["Content-Disposition"] = f'attachment; filename="{grid_out.filename}"'
188
 
189
  return StreamingResponse(iterfile(), media_type=grid_out.content_type or "application/octet-stream", headers=headers)
 
 
 
 
 
 
 
 
11
  from pymongo import MongoClient
12
  import gridfs
13
  from bson.objectid import ObjectId
 
 
 
 
 
14
  from PIL import Image
15
+ import torch
16
+
17
+ # Hugging Face Inference client
18
+ from huggingface_hub import InferenceClient
19
 
20
  # ---------------------------------------------------------------------
21
+ # Environment & MongoDB setup
22
  # ---------------------------------------------------------------------
23
+ HF_TOKEN = os.getenv("HF_TOKEN") # Must set in .env
24
+ if not HF_TOKEN:
25
+ raise RuntimeError("HF_TOKEN not set in environment variables")
26
+
27
  MONGODB_URI = "mongodb+srv://harilogicgo_db_user:jFhyDM4oA4dklUsp@api-logs.i7rqf9p.mongodb.net/?appName=API-LOGS"
28
  DB_NAME = "polaroid_db"
29
 
 
32
  fs = gridfs.GridFS(db)
33
  logs_collection = db["logs"]
34
 
35
+ # HF Inference client
36
+ hf_client = InferenceClient(token=HF_TOKEN)
37
+
38
  # ---------------------------------------------------------------------
39
+ # FastAPI app
40
  # ---------------------------------------------------------------------
41
  app = FastAPI(title="Qwen Image Edit API")
42
  app.add_middleware(
 
53
  BEARER_TOKEN = "logicgo@123"
54
 
55
  def verify_token(authorization: str = Header(None)):
 
56
  if not authorization or not authorization.startswith("Bearer "):
57
  raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
58
  token = authorization.split(" ")[1]
 
60
  raise HTTPException(status_code=403, detail="Invalid bearer token")
61
  return True
62
 
 
 
 
 
 
 
 
 
63
  # ---------------------------------------------------------------------
64
  # Models
65
  # ---------------------------------------------------------------------
 
73
  # ---------------------------------------------------------------------
74
  @app.get("/health", response_model=HealthResponse)
75
  def health():
 
76
  try:
77
  mongo.admin.command("ping")
78
  return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit")
 
86
  image2: Optional[UploadFile] = File(None),
87
  authorized: bool = Depends(verify_token)
88
  ):
89
+ """Upload 1 or 2 images + prompt and get edited image via HF Inference"""
90
+ # Read images
 
 
 
 
91
  try:
92
  img1_bytes = await image1.read()
 
 
93
  pil_img1 = Image.open(io.BytesIO(img1_bytes)).convert("RGB")
94
  except Exception as e:
95
  raise HTTPException(status_code=400, detail=f"Failed to read first image: {e}")
96
 
 
 
 
97
  pil_img2 = None
98
+ img2_bytes = None
99
  if image2:
100
  try:
101
  img2_bytes = await image2.read()
 
104
  except Exception as e:
105
  raise HTTPException(status_code=400, detail=f"Failed to read second image: {e}")
106
 
107
+ # Save input images to GridFS
 
 
108
  try:
109
  input1_id = fs.put(img1_bytes, filename=image1.filename, contentType=image1.content_type, role="input")
110
  input2_id = None
 
113
  except Exception as e:
114
  raise HTTPException(status_code=500, detail=f"Failed saving input images: {e}")
115
 
116
+ # Run HF Inference
 
 
117
  try:
118
  images_to_pass = [pil_img1]
119
  if pil_img2:
120
  images_to_pass.append(pil_img2)
121
 
122
+ # The "Qwen/Qwen-Image-Edit" expects a list of PIL images and a prompt
123
+ pil_output = hf_client.image_to_image(
124
+ images=images_to_pass,
125
+ prompt=prompt,
126
+ model="Qwen/Qwen-Image-Edit"
127
+ )
128
+
129
+ # Handle list return
130
+ if isinstance(pil_output, list):
131
+ output_image = pil_output[0]
132
+ else:
133
+ output_image = pil_output
134
 
 
135
  out_buf = io.BytesIO()
136
  output_image.save(out_buf, format="PNG")
137
  out_bytes = out_buf.getvalue()
 
139
  traceback.print_exc()
140
  raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
141
 
142
+ # Save output image
 
 
143
  try:
144
+ out_id = fs.put(
145
+ out_bytes,
146
+ filename=f"result_{input1_id}.png",
147
+ contentType="image/png",
148
+ prompt=prompt,
149
+ role="output",
150
+ input1_id=str(input1_id),
151
+ input2_id=str(input2_id) if input2_id else None
152
+ )
153
  except Exception as e:
154
  raise HTTPException(status_code=500, detail=f"Failed saving output image: {e}")
155
 
156
+ # Log request
 
 
157
  try:
158
  logs_collection.insert_one({
159
  "timestamp": datetime.utcnow(),
 
167
 
168
  return JSONResponse({"output_image_id": str(out_id)})
169
 
170
+
171
  @app.get("/image/{image_id}")
172
  def get_image(image_id: str, download: Optional[bool] = False):
173
  """Retrieve image by GridFS ID"""
 
185
  headers["Content-Disposition"] = f'attachment; filename="{grid_out.filename}"'
186
 
187
  return StreamingResponse(iterfile(), media_type=grid_out.content_type or "application/octet-stream", headers=headers)
188
+
189
+ # ---------------------------------------------------------------------
190
+ # Run locally
191
+ # ---------------------------------------------------------------------
192
+ if __name__ == "__main__":
193
+ import uvicorn
194
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)