Yusuf commited on
Commit
1cb71bc
·
1 Parent(s): d4a4907

predict on test set & log to clearml

Browse files
testingModel/helpers/evaluation.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import CrossEntropyLoss
3
+
4
+
5
+ """
6
+ Evaluates a trained model on a dataloader that returns batches like:
7
+ batch["image"] -> Tensor [B, 3, 256, 256]
8
+ batch["label"] -> Tensor [B]
9
+
10
+ Returns dict:
11
+ { "accuracy": float, "loss": float }
12
+ """
13
+ def make_predictions(model, dataloader, device):
14
+
15
+ model.eval()
16
+ criterion = CrossEntropyLoss()
17
+
18
+ total_loss = 0
19
+ total_correct = 0
20
+ total_samples = 0
21
+
22
+ with torch.no_grad():
23
+ for batch in dataloader:
24
+
25
+ # Move tensors to device
26
+ images = batch["image"].to(device)
27
+ labels = batch["label"].to(device).long()
28
+
29
+ # Forward pass
30
+ outputs = model(images)
31
+ loss = criterion(outputs, labels)
32
+
33
+ total_loss += loss.item() * images.size(0)
34
+ total_correct += (outputs.argmax(dim=1) == labels).sum().item()
35
+ total_samples += labels.size(0)
36
+
37
+ accuracy = total_correct / total_samples
38
+ avg_loss = total_loss / total_samples
39
+
40
+ return {
41
+ "accuracy": accuracy,
42
+ "loss": avg_loss,
43
+ }
testingModel/run_testing.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from clearml import Task
2
+ from dataPrep.helpers.clearml_data import extract_latest_data_task
3
+
4
+ import torch
5
+ from models.modelOne import modelOne
6
+ from testingModel.helpers.evaluation import make_predictions
7
+
8
+
9
+ # -------------- Load Data --------------
10
+ project_name = "Small Group Project"
11
+ subset_loaders, full_loaders, data_prep_metadata = extract_latest_data_task(project_name=project_name)
12
+
13
+
14
+ # -------- ClearML Testing Task Setup --------
15
+ testing_task = Task.init(
16
+ project_name="Small Group Project",
17
+ task_name="Model Testing",
18
+ task_type=Task.TaskTypes.testing,
19
+ reuse_last_task_id=False,
20
+ )
21
+
22
+ # Reference the data prep task used
23
+ testing_logger = testing_task.get_logger()
24
+ testing_task.connect(data_prep_metadata, name="data_prep_metadata_READONLY")
25
+
26
+ CLEARML_TRAINING_ID = "5bac154a885b4acbaa07d8588027bb27"
27
+
28
+ # Testing parameters - Modify these when experimenting
29
+ testing_config = {
30
+ "model_train_id": CLEARML_TRAINING_ID,
31
+ "num_classes": 39,
32
+ "model_path": "best_model.pt",
33
+ }
34
+ testing_task.connect(testing_config)
35
+
36
+ # Load the model weights from ClearML training task
37
+ training_task = Task.get_task(task_id=testing_config["model_train_id"])
38
+ model_artifact = training_task.artifacts.get("best_model")
39
+ model_path = model_artifact.get_local_copy()
40
+
41
+ model = modelOne()
42
+ state_dict = torch.load(model_path, map_location="cpu") # Load to CPU first
43
+ model.load_state_dict(state_dict)
44
+ model.eval() # set dropout & batch norm layers to eval mode
45
+
46
+ # Move model to GPU if available
47
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ model.to(device)
49
+
50
+
51
+ # -------------------- Test model on test set --------------------
52
+ testing_logger.report_text("Starting evaluation on TEST SUBSET...\n")
53
+ test_subset = subset_loaders['test']
54
+
55
+ subset_results = make_predictions(model, test_subset, device)
56
+
57
+
58
+ # Accuracy & Loss logging
59
+ testing_logger.report_single_value(name="Test Subset Accuracy", value=subset_results["accuracy"])
60
+ testing_logger.report_single_value(name="Test Subset Loss", value=subset_results["loss"])
61
+
62
+
63
+ # --------- Complete -----------------
64
+ print("\n------ Testing Complete ------")
65
+ testing_logger.report_text(
66
+ f"TEST SUBSET RESULTS:\n"
67
+ f"Loss: {subset_results['loss']:.4f}\n"
68
+ f"Accuracy: {subset_results['accuracy']:.4f}\n"
69
+ )
70
+ testing_task.close()