Clarification on two-head loss weighting vs inference output (head_2 upweighted, head_1 reported)
@sb-sb -- I was hoping to get some clarity on the following train/inference behavior:
- Two heads are instantiated for training.
- head_2 is heavily upweighted in the training loss:
  - regression loss: runner.py#L95 (https://huggingface.co/SandboxAQ/AQAffinity/blob/main/src/aqaffinity/modules/model/runner.py#L95)
  - decoy loss: runner.py#L117 (https://huggingface.co/SandboxAQ/AQAffinity/blob/main/src/aqaffinity/modules/model/runner.py#L117)
  - (the classification branch also uses `*50`, though it is currently disabled): runner.py#L109 (https://huggingface.co/SandboxAQ/AQAffinity/blob/main/src/aqaffinity/modules/model/runner.py#L109)
- In inference, both heads are computed, but only head_1 is used for the reported affinity output:
  - selecting head_1: affinity_writer.py#L115 (https://huggingface.co/SandboxAQ/AQAffinity/blob/main/src/aqaffinity/modules/callbacks/affinity_writer.py#L115)
  - writing the output file: affinity_writer.py#L132 (https://huggingface.co/SandboxAQ/AQAffinity/blob/main/src/aqaffinity/modules/callbacks/affinity_writer.py#L132)
- The head definitions differ (head_1 is structure-aware, head_2 structure-agnostic):
  - head_1 `include_structure=True`: binding_architecture.py#L41 (https://huggingface.co/SandboxAQ/AQAffinity/blob/main/src/aqaffinity/modules/configs/binding_architecture.py#L41)
  - head_2 `include_structure=False`: binding_architecture.py#L75 (https://huggingface.co/SandboxAQ/AQAffinity/blob/main/src/aqaffinity/modules/configs/binding_architecture.py#L75)
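For concreteness, the weighting pattern above can be sketched as follows. This is a minimal illustration, not the actual AQAffinity code: only the `*50` factor comes from the linked `runner.py` lines, and all function and argument names here are assumptions.

```python
# Hypothetical sketch of the head_2 upweighting described above.
# Only the *50 factor is taken from the linked runner.py lines.

HEAD_2_WEIGHT = 50.0  # matches the *50 factor applied to head_2 losses

def combined_loss(reg_1: float, decoy_1: float,
                  reg_2: float, decoy_2: float) -> float:
    """head_1 losses enter at weight 1; head_2 losses are upweighted by 50,
    which effectively raises the learning rate for the small head_2 module."""
    return (reg_1 + decoy_1) + HEAD_2_WEIGHT * (reg_2 + decoy_2)
```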
Questions:
- Is the strong head_2 weighting intentional?
- If yes, what is the motivation for reporting only head_1 at inference?
- For downstream use, should users prefer head_1, head_2, or an ensemble/weighted combination?
head_2 is a very simple shallow MLP acting on the trunk embeddings. It does not require the structure, and is small enough that it was trained simultaneously with the main model. The strong head_2 weighting was intentional as a dirty trick to increase the learning rate for that module.
The motivation for only reporting head_1 is that head_2 is quite a bit worse in practice. It is trained, though, and it might make sense to do some ensemble weighting to get slightly higher performance.
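If one does try ensembling the two heads, a simple convex combination of their predictions is the obvious starting point. A sketch, assuming scalar affinity predictions; the default weight is arbitrary, not a value tuned for AQAffinity:

```python
def ensemble_affinity(head_1_pred: float, head_2_pred: float,
                      w: float = 0.8) -> float:
    """Convex combination of the two heads' affinity predictions.
    w is placed on head_1 since it is the stronger head in practice;
    w=0.8 is an assumed default, not a tuned value."""
    assert 0.0 <= w <= 1.0
    return w * head_1_pred + (1.0 - w) * head_2_pred
```

In practice `w` would be chosen on a held-out validation set; `w=1.0` recovers the current behavior of reporting head_1 alone.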
Thanks @maartensandbox!
Out of curiosity: what's the motivation for training these two heads together in the first place? Why not just train head_1?
Because it was essentially zero-cost. Some time goes to loading data points, transferring data to the GPU, and backpropagating through the main model; simultaneously training this simple head doesn't cause a measurable slowdown.
The original hope was that a significantly simpler, structure-free method might perform better out of distribution, and that an ensembling approach would then make sense. In the end, we didn't pursue that further.