SaniaE commited on
Commit
9fef689
·
verified ·
1 Parent(s): cb1efba

parallelized caption generation

Browse files
Files changed (1) hide show
  1. app.py +32 -31
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
  import random
4
  from PIL import Image
@@ -77,6 +78,31 @@ async def startup_event():
77
  SEARCH_MODEL = SentenceTransformer('clip-ViT-B-32')
78
  print("Ensemble is live!")
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  @app.post("/generate")
81
  async def generate_endpoint(
82
  file: UploadFile = File(...),
@@ -85,41 +111,16 @@ async def generate_endpoint(
85
  top_p: float = Query(0.9)
86
  ):
87
  image = Image.open(file.file).convert("RGB")
88
- captions = []
89
-
90
- # Randomly select which models to use for the 5 slots
91
  available = list(MODELS.keys())
92
  model_selection = random.choices(available, k=5)
93
-
94
- print("Selected models: ", model_selection)
95
-
96
- for m_name in model_selection:
97
- m_data = MODELS[m_name]
98
- model = m_data["model"]
99
-
100
- if m_name == "vit":
101
- i_proc, t_proc = m_data["processor"]
102
- pixel_values = i_proc(images=image, return_tensors="pt").pixel_values.to(DEVICE)
103
- gen_ids = model.generate(
104
- pixel_values=pixel_values, max_length=300, do_sample=True,
105
- temperature=temp, top_k=top_k, top_p=top_p
106
- )
107
- cap = t_proc.batch_decode(gen_ids, skip_special_tokens=True)[0]
108
- else:
109
- proc = m_data["processor"]
110
- pixel_values = proc(images=image, return_tensors="pt").pixel_values.to(DEVICE)
111
- gen_ids = model.generate(
112
- pixel_values=pixel_values, max_length=300, do_sample=True,
113
- temperature=temp, top_k=top_k, top_p=top_p
114
- )
115
- cap = proc.batch_decode(gen_ids, skip_special_tokens=True)[0]
116
-
117
- captions.append(cap.strip())
118
-
119
- print("Caption generated: ", cap.strip())
120
-
121
  return {"captions": captions, "mix": model_selection}
122
 
 
123
  @app.post("/ui-tester")
124
  async def ui_tester(file: UploadFile = File(...), description: str = Query(...)):
125
  """Matches a user description against an image using CLIP embeddings."""
 
1
  import os
2
+ import asyncio
3
  import torch
4
  import random
5
  from PIL import Image
 
78
  SEARCH_MODEL = SentenceTransformer('clip-ViT-B-32')
79
  print("Ensemble is live!")
80
 
81
async def run_inference(m_name, image, temp, top_k, top_p):
    """Await a caption from one model without blocking the event loop.

    The actual model call (`_generate_sync`) is synchronous and CPU/GPU
    heavy, so it is offloaded to a worker thread via asyncio.to_thread.
    """
    worker = asyncio.to_thread(_generate_sync, m_name, image, temp, top_k, top_p)
    return await worker
84
+
85
def _generate_sync(m_name, image, temp, top_k, top_p):
    """Generate one caption for *image* with the model registered under *m_name*.

    Fully synchronous: meant to be called from a worker thread (see
    run_inference), never directly on the event loop.

    Args:
        m_name: key into the module-level MODELS registry.
        image: PIL.Image already converted to RGB by the caller.
        temp, top_k, top_p: sampling parameters forwarded to `generate`.

    Returns:
        The decoded caption, whitespace-stripped.
    """
    m_data = MODELS[m_name]
    model = m_data["model"]

    # The "vit" entry stores a (image_processor, tokenizer) pair; every other
    # entry stores a single processor object that plays both roles. Unify the
    # two cases up front so the generation/decoding path exists only once.
    if m_name == "vit":
        image_proc, text_proc = m_data["processor"]
    else:
        image_proc = text_proc = m_data["processor"]

    pixel_values = image_proc(images=image, return_tensors="pt").pixel_values.to(DEVICE)
    # Inference only — disable autograd bookkeeping to save memory/time.
    with torch.no_grad():
        gen_ids = model.generate(
            pixel_values=pixel_values, max_length=300, do_sample=True,
            temperature=temp, top_k=top_k, top_p=top_p
        )
    return text_proc.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
105
+
106
  @app.post("/generate")
107
  async def generate_endpoint(
108
  file: UploadFile = File(...),
 
111
  top_p: float = Query(0.9)
112
  ):
113
  image = Image.open(file.file).convert("RGB")
 
 
 
114
  available = list(MODELS.keys())
115
  model_selection = random.choices(available, k=5)
116
+
117
+ # Create tasks for parallel execution
118
+ tasks = [run_inference(m, image, temp, top_k, top_p) for m in model_selection]
119
+ captions = await asyncio.gather(*tasks)
120
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  return {"captions": captions, "mix": model_selection}
122
 
123
+
124
  @app.post("/ui-tester")
125
  async def ui_tester(file: UploadFile = File(...), description: str = Query(...)):
126
  """Matches a user description against an image using CLIP embeddings."""