LiuZichen commited on
Commit
92f4b7f
·
verified ·
1 Parent(s): 08cd4bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -33
app.py CHANGED
@@ -1,14 +1,7 @@
1
  import subprocess
2
  import shlex
3
- # Install the custom component if needed
4
- subprocess.run(
5
- shlex.split(
6
- "pip install ./gradio_magicquillv2-0.0.1-py3-none-any.whl"
7
- )
8
- )
9
  import sys
10
  import os
11
- import gradio as gr
12
  import tempfile
13
  import numpy as np
14
  import io
@@ -16,13 +9,21 @@ import base64
16
  import json
17
  import uvicorn
18
  import torch
 
 
 
 
 
 
 
 
 
 
19
  from fastapi import FastAPI, Request
20
  from fastapi.concurrency import run_in_threadpool
21
  from fastapi.middleware.cors import CORSMiddleware
22
  from gradio_client import Client, handle_file
23
  from gradio_magicquillv2 import MagicQuillV2
24
- from PIL import Image
25
-
26
 
27
  from util import (
28
  read_base64_image as read_base64_image_utils,
@@ -32,18 +33,13 @@ from util import (
32
 
33
  # --- Configuration ---
34
  # Set this to the URL of your backend Space (running app_backend.py)
35
- # Example: "https://huggingface.co/spaces/username/backend-space"
36
  BACKEND_URL = "LiuZichen/MagicQuillV2"
37
  SAM_URL = "LiuZichen/MagicQuillHelper"
38
 
39
- print(f"Connecting to backend at: {BACKEND_URL}")
40
-
41
- try:
42
- backend_client = Client(BACKEND_URL)
43
- except Exception as e:
44
- print(f"Failed to connect to backend: {e}")
45
- backend_client = None
46
 
 
 
47
  print(f"Connecting to SAM client at: {SAM_URL}")
48
  try:
49
  sam_client = Client(SAM_URL)
@@ -53,7 +49,32 @@ except Exception as e:
53
 
54
  # --- Helper Functions ---
55
 
56
- def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  merged_image = x['from_frontend']['img']
58
  total_mask = x['from_frontend']['total_mask']
59
  original_image = x['from_frontend']['original_image']
@@ -64,15 +85,17 @@ def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_
64
  add_prop_image = x['from_frontend']['add_prop_image']
65
  positive_prompt = x['from_backend']['prompt']
66
 
67
- if backend_client is None:
68
- print("Backend client not initialized")
69
- x["from_backend"]["generated_image"] = None
70
- return x
71
 
72
  try:
 
 
 
 
 
73
  # Call the backend API
74
- # The order of arguments must match app_backend.py input list
75
- res_base64 = backend_client.predict(
76
  merged_image, # merged_image
77
  total_mask, # total_mask
78
  original_image, # original_image
@@ -107,8 +130,8 @@ with gr.Blocks(title="MagicQuill V2") as demo:
107
  with gr.Row(elem_classes="row"):
108
  text = gr.Markdown(
109
  """
110
- # Welcome to MagicQuill V2! Give us a [GitHub star](https://github.com/zliucz/magicquillv2) if you are interested.
111
- Click the [link](https://magicquill.art/v2) to view our demo and tutorial. The paper is on [ArXiv](https://arxiv.org/abs/2512.03046) now.
112
  """)
113
 
114
  with gr.Row():
@@ -132,6 +155,8 @@ with gr.Blocks(title="MagicQuill V2") as demo:
132
 
133
  btn.click(
134
  generate_image_handler,
 
 
135
  inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg],
136
  outputs=ms
137
  )
@@ -150,18 +175,20 @@ app.add_middleware(
150
  def get_root_url(request: Request, route_path: str, root_path: str | None):
151
  return root_path
152
  gr.route_utils.get_root_url = get_root_url
153
- # gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo")
154
 
155
  @app.post("/magic_quill/generate_image")
156
  async def generate_image(request: Request):
157
  data = await request.json()
158
 
159
- if backend_client is None:
160
- return {'error': 'Backend client not connected'}
161
 
162
  try:
 
 
 
163
  res = await run_in_threadpool(
164
- backend_client.predict,
165
  data["merged_image"],
166
  data["total_mask"],
167
  data["original_image"],
@@ -214,7 +241,7 @@ async def segmentation(request: Request):
214
  if sam_client is None:
215
  return {"error": "sam client not initialized"}
216
 
217
- # Process coordinates and bboxes (copied from original app.py)
218
  pos_coordinates = None
219
  if coordinates_positive and len(coordinates_positive) > 0:
220
  pos_coordinates = []
@@ -345,8 +372,6 @@ async def segmentation(request: Request):
345
  else:
346
  seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
347
 
348
- print(seg_bbox)
349
-
350
  buffered = io.BytesIO()
351
  res_pil.save(buffered, format="PNG")
352
  image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
1
  import subprocess
2
  import shlex
 
 
 
 
 
 
3
  import sys
4
  import os
 
5
  import tempfile
6
  import numpy as np
7
  import io
 
9
  import json
10
  import uvicorn
11
  import torch
12
+ from PIL import Image
13
+
14
+ # Install the custom component if needed
15
+ subprocess.run(
16
+ shlex.split(
17
+ "pip install ./gradio_magicquillv2-0.0.1-py3-none-any.whl"
18
+ )
19
+ )
20
+
21
+ import gradio as gr
22
  from fastapi import FastAPI, Request
23
  from fastapi.concurrency import run_in_threadpool
24
  from fastapi.middleware.cors import CORSMiddleware
25
  from gradio_client import Client, handle_file
26
  from gradio_magicquillv2 import MagicQuillV2
 
 
27
 
28
  from util import (
29
  read_base64_image as read_base64_image_utils,
 
33
 
34
  # --- Configuration ---
35
  # Set this to the URL of your backend Space (running app_backend.py)
 
36
  BACKEND_URL = "LiuZichen/MagicQuillV2"
37
  SAM_URL = "LiuZichen/MagicQuillHelper"
38
 
39
+ print(f"Target Backend URL: {BACKEND_URL}")
 
 
 
 
 
 
40
 
41
+ # We still initialize SAM client globally as it might not require ZeroGPU quotas
42
+ # or is a helper CPU space.
43
  print(f"Connecting to SAM client at: {SAM_URL}")
44
  try:
45
  sam_client = Client(SAM_URL)
 
49
 
50
  # --- Helper Functions ---
51
 
52
+ def get_zerogpu_headers(request_headers):
53
+ """
54
+ Extracts ZeroGPU specific headers from the incoming request headers.
55
+ These are required to forward the user's quota token to the backend.
56
+ """
57
+ headers = {}
58
+ if request_headers:
59
+ # These are the headers HF injects for ZeroGPU authentication and tracking
60
+ target_headers = [
61
+ "x-ip-token",
62
+ "x-zerogpu-token",
63
+ "x-zerogpu-uuid",
64
+ "authorization",
65
+ "cookie"
66
+ ]
67
+ for h in target_headers:
68
+ val = request_headers.get(h)
69
+ if val:
70
+ headers[h] = val
71
+ return headers
72
+
73
+ def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg, request: gr.Request):
74
+ """
75
+ Handler for the Gradio UI.
76
+ Note the 'request: gr.Request' argument - Gradio automatically injects this.
77
+ """
78
  merged_image = x['from_frontend']['img']
79
  total_mask = x['from_frontend']['total_mask']
80
  original_image = x['from_frontend']['original_image']
 
85
  add_prop_image = x['from_frontend']['add_prop_image']
86
  positive_prompt = x['from_backend']['prompt']
87
 
88
+ # 1. Extract headers from the current user's request
89
+ forward_headers = get_zerogpu_headers(request.headers)
 
 
90
 
91
  try:
92
+ # 2. Instantiate a client specifically for this request with the forwarded headers.
93
+ # This ensures the backend sees the 'x-zerogpu-token' of the user, not the server.
94
+ # gradio_client caches schemas, so re-init is relatively cheap but necessary for headers.
95
+ client = Client(BACKEND_URL, headers=forward_headers)
96
+
97
  # Call the backend API
98
+ res_base64 = client.predict(
 
99
  merged_image, # merged_image
100
  total_mask, # total_mask
101
  original_image, # original_image
 
130
  with gr.Row(elem_classes="row"):
131
  text = gr.Markdown(
132
  """
133
+ # Welcome to MagicQuill V2! Give us a [GitHub star] if you are interested.
134
+ Click the [link] to view our demo and tutorial. The paper is on [ArXiv] now.
135
  """)
136
 
137
  with gr.Row():
 
155
 
156
  btn.click(
157
  generate_image_handler,
158
+ # Note: We do NOT need to explicitly add 'request' to inputs here.
159
+ # Gradio handles type hinting for gr.Request automatically.
160
  inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg],
161
  outputs=ms
162
  )
 
175
  def get_root_url(request: Request, route_path: str, root_path: str | None):
176
  return root_path
177
  gr.route_utils.get_root_url = get_root_url
 
178
 
179
  @app.post("/magic_quill/generate_image")
180
  async def generate_image(request: Request):
181
  data = await request.json()
182
 
183
+ # 1. Extract headers from the FastAPI request object
184
+ forward_headers = get_zerogpu_headers(request.headers)
185
 
186
  try:
187
+ # 2. Instantiate client with headers
188
+ client = Client(BACKEND_URL, headers=forward_headers)
189
+
190
  res = await run_in_threadpool(
191
+ client.predict,
192
  data["merged_image"],
193
  data["total_mask"],
194
  data["original_image"],
 
241
  if sam_client is None:
242
  return {"error": "sam client not initialized"}
243
 
244
+ # Process coordinates and bboxes
245
  pos_coordinates = None
246
  if coordinates_positive and len(coordinates_positive) > 0:
247
  pos_coordinates = []
 
372
  else:
373
  seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
374
 
 
 
375
  buffered = io.BytesIO()
376
  res_pil.save(buffered, format="PNG")
377
  image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")