Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -150,16 +150,17 @@ def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None
|
|
| 150 |
|
| 151 |
@spaces.GPU()
|
| 152 |
def generate_pali(user_emb):
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
| 163 |
return decoded
|
| 164 |
|
| 165 |
|
|
|
|
| 150 |
|
| 151 |
@spaces.GPU()
def generate_pali(user_emb):
    """Sample a caption from ``pali`` conditioned on ``user_emb``.

    The embedding is tiled across 256 positions and passed to
    ``to_wanted_embs`` as its first argument (named ``image_outputs`` in that
    helper's signature), together with the token ids and attention mask of
    the fixed prompt ``'caption en'``.  Whatever ``to_wanted_embs`` returns is
    handed to ``pali.generate`` as ``inputs_embeds``.

    Args:
        user_emb: Tensor expected to squeeze down to a single 1-D vector;
            ``[None, None, :]`` followed by ``repeat(1, 256, 1)`` only works
            on a 1-D input.  Presumably its width must match the feature
            size ``to_wanted_embs`` expects -- TODO confirm against the helper.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens skipped.
    """
    prompt = 'caption en'

    # Inference only: no autograd graph is needed for sampling.
    with torch.no_grad():
        # The processor is called with an all-zero placeholder image; only
        # the resulting token ids and attention mask are read below.
        model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")

        # Tile the single vector into a (1, 256, dim) tensor that stands in
        # for image features.
        # NOTE(review): 256 is presumably the number of image-token slots the
        # model expects for a 224x224 input -- confirm.
        tiled_user_emb = user_emb.squeeze()[None, None, :].repeat(1, 256, 1)

        input_embeds = to_wanted_embs(
            tiled_user_emb,
            model_inputs["input_ids"].to(device),
            model_inputs["attention_mask"].to(device),
        )

        # Stochastic decoding (nucleus sampling with temperature), so
        # repeated calls with the same embedding can return different text.
        generation = pali.generate(
            max_new_tokens=100,
            do_sample=True,
            top_p=.94,
            temperature=1.2,
            inputs_embeds=input_embeds,
        )
        decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return decoded
|
| 165 |
|
| 166 |
|