Spaces: Running on Zero

Darius Morawiec committed
Commit 57edf3e · 1 Parent(s): abc111d

Refactor model loading and inferring
app.py CHANGED
@@ -25,20 +25,14 @@ else:
 
     class spaces:
         @staticmethod
-        def GPU(func
+        def GPU(func):
             def wrapper(*args, **kwargs):
                 return func(*args, **kwargs)
 
             return wrapper
 
 
-@spaces.GPU
-def dummy():
-    return
-
-
 # Define constants
-GPU_DURATION = 300
 EXAMPLES_DIR = Path(__file__).parent / "examples"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 MODEL_IDS = [
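This hunk simplifies the CPU fallback for the `spaces` package: the stub's `GPU` decorator now returns the wrapped function unchanged, and the `dummy()` warm-up function together with the now-unused `GPU_DURATION` constant are removed. For context, here is a minimal sketch of the guard implied by the `else:` hunk header; the condition used below is an assumption, not copied from the Space:

import importlib.util

# Sketch only: the Space's actual condition around this `else:` may differ.
if importlib.util.find_spec("spaces") is not None:
    import spaces  # real package, present on ZeroGPU hardware
else:
    # Local/CPU fallback: a no-op decorator so `@spaces.GPU` still applies.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper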
@@ -199,20 +193,20 @@ with gr.Blocks() as demo:
         ).eval()
         current_processor = AutoProcessor.from_pretrained(model_id)
         current_model_id = model_id
-
         return current_model, current_processor
 
-    def run(
+    @spaces.GPU
+    def generate(
+        model,
+        processor,
         image,
         model_id: str,
         system_prompt: str,
         user_prompt: str,
-        max_new_tokens: int = 1024,
-        image_resize: str = "Yes",
-        image_target_size: int | None = None,
+        max_new_tokens: int,
+        image_resize: str,
+        image_target_size: int | None,
     ):
-        model, processor = load_model(model_id)
-
         base64_image = image_to_base64(
             resize_image(image, image_target_size)
             if image_resize == "Yes" and image_target_size
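The refactor pulls model loading out of the GPU path: `generate` now receives an already-loaded `model` and `processor` as arguments and carries the `@spaces.GPU` decorator, so only inference runs inside the ZeroGPU window while the expensive `from_pretrained` calls stay on the CPU side. Below is a rough sketch of the cached `load_model` implied by the context lines; the global cache variables and the model class are assumptions, not the Space's exact code:

from transformers import AutoModelForImageTextToText, AutoProcessor

current_model = None
current_processor = None
current_model_id = None


def load_model(model_id: str):
    # Reload only when a different model id is requested; otherwise
    # reuse the cached model and processor.
    global current_model, current_processor, current_model_id
    if model_id != current_model_id:
        current_model = AutoModelForImageTextToText.from_pretrained(
            model_id
        ).eval()
        current_processor = AutoProcessor.from_pretrained(model_id)
        current_model_id = model_id
    return current_model, current_processor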
@@ -247,8 +241,7 @@ with gr.Blocks() as demo:
         )
         inputs = inputs.to(DEVICE)
 
-
-        generated_ids = generate(**inputs, max_new_tokens=max_new_tokens)
+        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
         generated_ids_trimmed = [
             out_ids[len(in_ids) :]
             for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
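`model.generate` returns each input sequence with the newly generated tokens appended, so the trimming step slices off the prompt portion before decoding. A toy, self-contained illustration of the idiom:

# Toy illustration: generate() output begins with the prompt tokens,
# so slicing at len(in_ids) keeps only the newly generated tail.
input_ids = [[101, 102, 103]]            # stand-in prompt token ids
generated_ids = [[101, 102, 103, 7, 8]]  # prompt + generated tokens
generated_ids_trimmed = [
    out_ids[len(in_ids):]
    for in_ids, out_ids in zip(input_ids, generated_ids)
]
assert generated_ids_trimmed == [[7, 8]]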
@@ -287,6 +280,29 @@ with gr.Blocks() as demo:
 
         return [(image, bboxes), str(json.dumps(output_json))]
 
+    def run(
+        image,
+        model_id: str,
+        system_prompt: str,
+        user_prompt: str,
+        max_new_tokens: int = 1024,
+        image_resize: str = "Yes",
+        image_target_size: int | None = None,
+    ):
+        model, processor = load_model(model_id)
+
+        return generate(
+            model,
+            processor,
+            image,
+            model_id,
+            system_prompt,
+            user_prompt,
+            max_new_tokens,
+            image_resize,
+            image_target_size,
+        )
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("## Examples")
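Keeping a `run` entry point with the original signature and defaults means the Gradio event handlers can stay untouched: `run` loads (or reuses) the model outside the GPU window, then delegates to the `@spaces.GPU`-decorated `generate`. A hypothetical call, assuming the `run` defined in this hunk; the image, model id, and prompts below are illustrative placeholders:

# Hypothetical usage; `pil_image` and the model id are placeholders.
annotated_image, raw_json = run(
    pil_image,
    model_id="some-org/some-vision-language-model",
    system_prompt="You are a helpful vision assistant.",
    user_prompt="Detect the objects and return bounding boxes as JSON.",
)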