Sadjad Alikhani commited on
Update lwm_model.py
Browse files- lwm_model.py +1 -47
lwm_model.py
CHANGED
|
@@ -10,11 +10,10 @@ import torch
|
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.nn.functional as F
|
| 12 |
import numpy as np
|
| 13 |
-
|
| 14 |
from inference import *
|
| 15 |
from load_data import load_DeepMIMO_data
|
| 16 |
from input_preprocess import *
|
| 17 |
-
|
| 18 |
|
| 19 |
|
| 20 |
ELEMENT_LENGTH = 16
|
|
@@ -111,51 +110,6 @@ class EncoderLayer(nn.Module):
|
|
| 111 |
attn_outputs = self.norm(attn_outputs)
|
| 112 |
enc_outputs = self.pos_ffn(attn_outputs)
|
| 113 |
return enc_outputs, attn
|
| 114 |
-
|
| 115 |
-
# class LWM(torch.nn.Module):
|
| 116 |
-
# def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
|
| 117 |
-
# super().__init__()
|
| 118 |
-
|
| 119 |
-
# self.embedding = Embedding(element_length, d_model, max_len)
|
| 120 |
-
# self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
|
| 121 |
-
# self.linear = nn.Linear(d_model, d_model)
|
| 122 |
-
# self.norm = LayerNormalization(d_model)
|
| 123 |
-
|
| 124 |
-
# embed_weight = self.embedding.proj.weight
|
| 125 |
-
# d_model, n_dim = embed_weight.size()
|
| 126 |
-
# self.decoder = nn.Linear(d_model, n_dim, bias=False)
|
| 127 |
-
# self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
|
| 128 |
-
# self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
|
| 129 |
-
|
| 130 |
-
# @classmethod
|
| 131 |
-
# def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
|
| 132 |
-
# # Define model
|
| 133 |
-
# model = cls().to(device)
|
| 134 |
-
|
| 135 |
-
# # Download the model weights (from a remote or local repository)
|
| 136 |
-
# ckpt_path = f'https://huggingface.co/sadjadalikhani/LWM/resolve/main/{ckpt_name}'
|
| 137 |
-
|
| 138 |
-
# # Load the model weights
|
| 139 |
-
# model.load_state_dict(torch.hub.load_state_dict_from_url(ckpt_path, map_location=device))
|
| 140 |
-
# print(f"Model loaded successfully from {ckpt_path} to {device}")
|
| 141 |
-
|
| 142 |
-
# return model
|
| 143 |
-
|
| 144 |
-
# def forward(self, input_ids, masked_pos):
|
| 145 |
-
# output = self.embedding(input_ids)
|
| 146 |
-
|
| 147 |
-
# for layer in self.layers:
|
| 148 |
-
# output, _ = layer(output)
|
| 149 |
-
|
| 150 |
-
# masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
|
| 151 |
-
# h_masked = torch.gather(output, 1, masked_pos)
|
| 152 |
-
# h_masked = self.norm(F.relu(self.linear(h_masked)))
|
| 153 |
-
# logits_lm = self.decoder(h_masked) + self.decoder_bias
|
| 154 |
-
|
| 155 |
-
# return logits_lm, output
|
| 156 |
-
|
| 157 |
-
from huggingface_hub import hf_hub_download
|
| 158 |
-
import torch
|
| 159 |
|
| 160 |
class LWM(torch.nn.Module):
|
| 161 |
def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
|
|
|
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.nn.functional as F
|
| 12 |
import numpy as np
|
|
|
|
| 13 |
from inference import *
|
| 14 |
from load_data import load_DeepMIMO_data
|
| 15 |
from input_preprocess import *
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
|
| 18 |
|
| 19 |
ELEMENT_LENGTH = 16
|
|
|
|
| 110 |
attn_outputs = self.norm(attn_outputs)
|
| 111 |
enc_outputs = self.pos_ffn(attn_outputs)
|
| 112 |
return enc_outputs, attn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
class LWM(torch.nn.Module):
|
| 115 |
def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
|