predict fix
Browse files- .dockerignore +22 -0
- predict.py +6 -3
.dockerignore
CHANGED
|
@@ -19,3 +19,25 @@ coverage.xml
|
|
| 19 |
.mypy_cache
|
| 20 |
.pytest_cache
|
| 21 |
.hypothesis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
.mypy_cache
|
| 20 |
.pytest_cache
|
| 21 |
.hypothesis
|
| 22 |
+
|
| 23 |
+
# generated by replicate/cog
|
| 24 |
+
__pycache__
|
| 25 |
+
*.pyc
|
| 26 |
+
*.pyo
|
| 27 |
+
*.pyd
|
| 28 |
+
.Python
|
| 29 |
+
env
|
| 30 |
+
pip-log.txt
|
| 31 |
+
pip-delete-this-directory.txt
|
| 32 |
+
.tox
|
| 33 |
+
.coverage
|
| 34 |
+
.coverage.*
|
| 35 |
+
.cache
|
| 36 |
+
nosetests.xml
|
| 37 |
+
coverage.xml
|
| 38 |
+
*.cover
|
| 39 |
+
*.log
|
| 40 |
+
.git
|
| 41 |
+
.mypy_cache
|
| 42 |
+
.pytest_cache
|
| 43 |
+
.hypothesis
|
predict.py
CHANGED
|
@@ -2,10 +2,13 @@ import os
|
|
| 2 |
from cog import BasePredictor, Input, Path
|
| 3 |
import torch
|
| 4 |
import json
|
|
|
|
|
|
|
|
|
|
| 5 |
from src.models.model import load_model
|
| 6 |
-
from src.
|
| 7 |
|
| 8 |
-
CHECKPOINT_DIR = "
|
| 9 |
|
| 10 |
class Predictor(BasePredictor):
|
| 11 |
def setup(self):
|
|
@@ -24,7 +27,7 @@ class Predictor(BasePredictor):
|
|
| 24 |
# Load model
|
| 25 |
self.model = load_model(
|
| 26 |
self.config['num_classes'],
|
| 27 |
-
os.path.join(CHECKPOINT_DIR, "
|
| 28 |
self.device,
|
| 29 |
self.config['clip_model']
|
| 30 |
)
|
|
|
|
| 2 |
from cog import BasePredictor, Input, Path
|
| 3 |
import torch
|
| 4 |
import json
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 7 |
+
|
| 8 |
from src.models.model import load_model
|
| 9 |
+
from src.dataset.video_utils import create_transform, extract_frames
|
| 10 |
|
| 11 |
+
CHECKPOINT_DIR = "checkpoints/"
|
| 12 |
|
| 13 |
class Predictor(BasePredictor):
|
| 14 |
def setup(self):
|
|
|
|
| 27 |
# Load model
|
| 28 |
self.model = load_model(
|
| 29 |
self.config['num_classes'],
|
| 30 |
+
os.path.join(CHECKPOINT_DIR, "weights.ckpt"),
|
| 31 |
self.device,
|
| 32 |
self.config['clip_model']
|
| 33 |
)
|