Spaces:
Build error
Build error
Add truncation
Browse files- app.py +2 -2
- fromage/models.py +6 -1
app.py
CHANGED
|
@@ -212,5 +212,5 @@ with gr.Blocks(css=css) as demo:
|
|
| 212 |
save_button.click(None, [], [], _js=save_js)
|
| 213 |
|
| 214 |
|
| 215 |
-
demo.queue(concurrency_count=1, api_open=False, max_size=16)
|
| 216 |
-
demo.launch(debug=True, server_name="
|
|
|
|
| 212 |
save_button.click(None, [], [], _js=save_js)
|
| 213 |
|
| 214 |
|
| 215 |
+
# demo.queue(concurrency_count=1, api_open=False, max_size=16)
|
| 216 |
+
demo.launch(debug=True, server_name="127.0.0.1")
|
fromage/models.py
CHANGED
|
@@ -525,6 +525,11 @@ class Fromage(nn.Module):
|
|
| 525 |
raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
|
| 526 |
input_embs = torch.cat(input_embs, dim=1)
|
| 527 |
input_ids = torch.cat(input_ids, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
|
| 529 |
print('L529 called')
|
| 530 |
if num_words == 0:
|
|
@@ -635,7 +640,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
|
|
| 635 |
assert len(ret_token_idx) == 1, ret_token_idx
|
| 636 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
| 637 |
|
| 638 |
-
debug =
|
| 639 |
if debug:
|
| 640 |
model_kwargs['opt_version'] = 'facebook/opt-125m'
|
| 641 |
model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
|
|
|
|
| 525 |
raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
|
| 526 |
input_embs = torch.cat(input_embs, dim=1)
|
| 527 |
input_ids = torch.cat(input_ids, dim=1)
|
| 528 |
+
# Trim to a reasonable max length, for demo purposes.
|
| 529 |
+
start_idx = max(input_embs.shape[1] - 512, 0)
|
| 530 |
+
input_embs = input_embs[:, start_idx:, :]
|
| 531 |
+
input_ids = input_ids[:, start_idx:]
|
| 532 |
+
print('input_embs.shape', input_embs.shape)
|
| 533 |
|
| 534 |
print('L529 called')
|
| 535 |
if num_words == 0:
|
|
|
|
| 640 |
assert len(ret_token_idx) == 1, ret_token_idx
|
| 641 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
| 642 |
|
| 643 |
+
debug = True
|
| 644 |
if debug:
|
| 645 |
model_kwargs['opt_version'] = 'facebook/opt-125m'
|
| 646 |
model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
|