Clarification on two-head loss weighting vs inference output (head_2 upweighted, head_1 reported)

#15
by KirillShmilovich - opened

@sb-sb -- I was hoping to get some clarity on the following train/inference behavior:

Questions:

  1. Is the strong head_2 weighting intentional?
  2. If yes, what is the motivation for reporting only head_1 at inference?
  3. For downstream use, should users prefer head_1, head_2, or an ensemble/weighted combination?
SandboxAQ org

head_2 is a very simple shallow MLP acting on the trunk embeddings. It does not require the structure, and is small enough that it was trained simultaneously with the main model. The strong head_2 weighting was intentional: it's a dirty trick to increase the effective learning rate for that module.
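To illustrate the trick described above, here is a minimal sketch of a weighted two-head joint loss. The MSE losses, the function names, and the weight value `w2` are all illustrative assumptions, not the repo's actual code; the point is only that scaling one head's loss term scales that head's gradients, which acts like a larger learning rate for that module alone.

```python
import numpy as np

def joint_loss(pred_head1, pred_head2, target, w2=10.0):
    """Combined training loss (hypothetical names and weighting).

    w2 > 1 upweights head_2's term, scaling its gradients and thereby
    its effective learning rate, without touching head_1's term.
    """
    loss1 = np.mean((pred_head1 - target) ** 2)  # main head (uses structure)
    loss2 = np.mean((pred_head2 - target) ** 2)  # shallow MLP on trunk embeddings
    return loss1 + w2 * loss2

# Example: both heads off by 1.0, so loss1 = loss2 = 1.0,
# and the combined loss is 1.0 + 10.0 * 1.0 = 11.0.
print(joint_loss(np.array([0.0]), np.array([0.0]), np.array([1.0])))
```

Note that upweighting a loss term is only equivalent to a per-module learning-rate bump for plain SGD; adaptive optimizers partially normalize the scale away, which is part of why this counts as a "dirty trick" rather than a principled per-parameter-group learning rate.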

The motivation to report only head_1 is that head_2 is quite a bit worse in practice. Head_2 is fully trained, though, and it might make sense to do some ensemble weighting to get slightly higher performance.
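The ensemble weighting mentioned above could look like the following sketch: a convex combination of the two heads' predictions, with the weight chosen on held-out data. The function name and the value of `alpha` are hypothetical; since head_2 is worse in practice, `alpha` would presumably sit close to 1.

```python
import numpy as np

def ensemble_predict(pred_head1, pred_head2, alpha=0.9):
    """Convex combination of the two heads' predictions.

    alpha close to 1 leans on head_1 (the stronger head); alpha would
    be tuned on a validation set rather than fixed a priori.
    """
    return alpha * pred_head1 + (1.0 - alpha) * pred_head2
```

With `alpha=1.0` this reduces exactly to the current behavior of reporting head_1 alone, so the ensemble can never be configured worse than the default if `alpha` is validated properly.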

Thanks @maartensandbox !

Out of curiosity: what's the motivation for training these two heads together in the first place? Why not just train head_1?

SandboxAQ org

Because it was essentially zero-cost. Some time goes to loading data points, transferring data to the GPU, and backpropagating through the main model; simultaneously training this simple head doesn't cause a measurable slowdown.

The original hope was that a significantly simpler, structure-free method might perform better out of distribution, and that an ensembling approach would then make sense. In the end we didn't pursue that further.
