Update README.md
Browse files
README.md
CHANGED
|
@@ -28,9 +28,11 @@ To load a pre-trained model ("VisionTransformer.pt"), use the following code sni
|
|
| 28 |
|
| 29 |
```python
|
| 30 |
import clip
|
| 31 |
-
from clip.downstream_task import TaskType
|
| 32 |
import torch # Make sure to import torch
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
device = "cpu" # Change to 'cuda' if you have a GPU
|
| 35 |
num_classes = 4 # Number of classes in the original HYPERVIEW dataset
|
| 36 |
|
|
@@ -43,4 +45,28 @@ model, _ = clip.load(
|
|
| 43 |
# Load the pre-trained weights
|
| 44 |
model.load_state_dict(torch.load("VisionTransformer.pt"))
|
| 45 |
model.eval() # Set the model to evaluation mode
|
| 46 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
```python
|
| 30 |
import clip
|
|
|
|
| 31 |
import torch # Make sure to import torch
|
| 32 |
|
| 33 |
+
from clip.downstream_task import TaskType
|
| 34 |
+
|
| 35 |
+
|
| 36 |
device = "cpu" # Change to 'cuda' if you have a GPU
|
| 37 |
num_classes = 4 # Number of classes in the original HYPERVIEW dataset
|
| 38 |
|
|
|
|
| 45 |
# Load the pre-trained weights
|
| 46 |
model.load_state_dict(torch.load("VisionTransformer.pt"))
|
| 47 |
model.eval() # Set the model to evaluation mode
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Loading training data
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
import numpy as np
|
| 54 |
+
|
| 55 |
+
from clip.hyperview_data_loader import HyperDataloader, DataReader
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
im_size = 224 # Image size
|
| 59 |
+
num_classes = 4 # Number of classes in the original HYPERVIEW dataset
|
| 60 |
+
|
| 61 |
+
# Paths to training data and ground truth
|
| 62 |
+
train_path = "<TRAIN_PATH>"
|
| 63 |
+
train_gt_path = "<TRAIN_PATH>/train_gt.csv"
|
| 64 |
+
|
| 65 |
+
# Initialize the dataset reader and transformations
|
| 66 |
+
target_index = list(np.arange(num_classes))
|
| 67 |
+
trans_tr, _ = HyperDataloader._init_transform(im_size)
|
| 68 |
+
train_dataset = DataReader(
|
| 69 |
+
database_dir=train_path, label_paths=train_gt_path,
|
| 70 |
+
transform=trans_tr, target_index=target_index
|
| 71 |
+
)
|
| 72 |
+
````
|