sharathmajjigi committed on
Commit
61ba6a6
·
1 Parent(s): c94a322

Add custom /v1/ground endpoint specifically for Agent-S

Browse files
Files changed (2) hide show
  1. app.py +76 -12
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from transformers import AutoProcessor, AutoModel
3
  import torch
@@ -6,6 +7,10 @@ import io
6
  import base64
7
  import json
8
  import numpy as np
 
 
 
 
9
 
10
  # UI-TARS model name
11
  model_name = "ByteDance-Seed/UI-TARS-1.5-7b"
@@ -37,7 +42,7 @@ def load_model():
37
 
38
  except Exception as e:
39
  print(f"❌ Error loading UI-TARS: {str(e)}")
40
- print("�� Attempting to load with fallback configuration...")
41
 
42
  try:
43
  # Fallback: Load without device_map
@@ -106,23 +111,82 @@ def process_grounding(image, prompt):
106
  "status": "failed"
107
  }
108
 
109
- # Create Gradio interface with API enabled
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  iface = gr.Interface(
111
  fn=process_grounding,
112
  inputs=[
113
  gr.Image(type="pil", label="Upload Screenshot"),
114
  gr.Textbox(label="Prompt/Goal", placeholder="What do you want to do?")
115
  ],
116
- outputs=gr.JSON(label="Grounding Results"), # Changed to JSON output
117
  title="UI-TARS Grounding Model",
118
- description="Upload a screenshot and describe your goal to get grounding results from UI-TARS",
119
- api_name="ground" # This creates /api/ground endpoint
120
  )
121
 
122
- # Launch with API enabled
123
- iface.launch(
124
- server_name="0.0.0.0",
125
- server_port=7860,
126
- share=False,
127
- show_api=True # This enables the API endpoints
128
- )
 
1
+ # app.py - Add Custom Endpoint for Agent-S
2
  import gradio as gr
3
  from transformers import AutoProcessor, AutoModel
4
  import torch
 
7
  import base64
8
  import json
9
  import numpy as np
10
+ from fastapi import FastAPI, Request
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.responses import JSONResponse
13
+ import uvicorn
14
 
15
  # UI-TARS model name
16
  model_name = "ByteDance-Seed/UI-TARS-1.5-7b"
 
42
 
43
  except Exception as e:
44
  print(f"❌ Error loading UI-TARS: {str(e)}")
45
+ print(" Attempting to load with fallback configuration...")
46
 
47
  try:
48
  # Fallback: Load without device_map
 
111
  "status": "failed"
112
  }
113
 
114
# Create the FastAPI application that hosts the JSON grounding endpoints.
app = FastAPI(title="UI-TARS Grounding API")

# Allow cross-origin requests from any site. Credentials are disabled on
# purpose: a wildcard origin must not be combined with credentialed requests,
# and none of the endpoints in this file rely on cookies or HTTP auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
125
+
126
+ # Custom endpoint specifically for Agent-S
127
@app.post("/v1/ground")
async def agent_s_grounding(request: Request):
    """Handle a grounding request sent by Agent-S.

    The JSON body must be an object in one of two shapes:

    * ``{"data": [image, prompt, ...]}`` - the first two list items are used.
    * ``{"image": ..., "prompt": ...}``

    The extracted values are forwarded unchanged to ``process_grounding``.

    Returns:
        JSONResponse: the ``process_grounding`` result on success; a 400
        response when the body is not valid JSON or matches neither shape;
        a 500 response when anything else raises.
    """
    invalid_format = {"error": "Invalid request format", "status": "failed"}

    try:
        try:
            body = await request.json()
        except ValueError:
            # Malformed JSON is a client error, not a server fault
            # (json.JSONDecodeError and UnicodeDecodeError are ValueErrors).
            return JSONResponse(status_code=400, content=invalid_format)

        # A JSON array, string, number or null cannot carry the fields.
        if not isinstance(body, dict):
            return JSONResponse(status_code=400, content=invalid_format)

        data = body.get("data")
        if isinstance(data, list) and len(data) >= 2:
            # Positional format: [image, prompt, ...]
            image, prompt = data[0], data[1]
        elif "image" in body and "prompt" in body:
            # Named-field format.
            image, prompt = body["image"], body["prompt"]
        else:
            return JSONResponse(status_code=400, content=invalid_format)

        result = process_grounding(image, prompt)
        return JSONResponse(content=result)

    except Exception as e:
        # Last-resort boundary: always answer with JSON, never a bare 500.
        return JSONResponse(
            status_code=500,
            content={"error": f"Internal server error: {str(e)}", "status": "failed"},
        )
159
+
160
+ # Alternative endpoint names for compatibility
161
@app.post("/api/ground")
async def api_ground(request: Request):
    """Compatibility alias: serve ``/api/ground`` via ``agent_s_grounding``."""
    response = await agent_s_grounding(request)
    return response
165
+
166
@app.post("/predict")
async def predict(request: Request):
    """Compatibility alias: serve ``/predict`` via ``agent_s_grounding``."""
    response = await agent_s_grounding(request)
    return response
170
+
171
@app.post("/")
async def root_endpoint(request: Request):
    """Compatibility alias: serve POST ``/`` via ``agent_s_grounding``."""
    response = await agent_s_grounding(request)
    return response
175
+
176
# Browser UI: a Gradio form that feeds a screenshot and a goal to
# process_grounding and displays the returned JSON.
iface = gr.Interface(
    title="UI-TARS Grounding Model",
    description="Upload a screenshot and describe your goal to get grounding results from UI-TARS",
    fn=process_grounding,
    inputs=[
        gr.Image(type="pil", label="Upload Screenshot"),
        gr.Textbox(label="Prompt/Goal", placeholder="What do you want to do?"),
    ],
    outputs=gr.JSON(label="Grounding Results"),
)

# Serve the Gradio UI under /gradio on the same FastAPI application, so the
# JSON endpoints and the UI share one server.
app = gr.mount_gradio_app(app, iface, path="/gradio")

# Start the combined server only when this file is run as a script.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
requirements.txt CHANGED
@@ -4,4 +4,6 @@ torchvision>=0.15.0
4
  accelerate>=0.20.0
5
  numpy>=1.21.0
6
  Pillow>=9.0.0
7
- gradio>=4.0.0
 
 
 
4
  accelerate>=0.20.0
5
  numpy>=1.21.0
6
  Pillow>=9.0.0
7
+ gradio>=4.0.0
8
+ fastapi>=0.100.0
9
+ uvicorn>=0.20.0