vyomie committed on
Commit
a653040
·
verified ·
1 Parent(s): c8e87d8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +88 -86
model.py CHANGED
@@ -1,86 +1,88 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
-
5
- # --- Liquid Network ---
6
class LiquidLayer(nn.Module):
    """One recurrent "liquid" cell: out = tanh(x @ W.T + prev_state @ U.T + bias)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        scale = 0.02  # small random init for both weight matrices
        self.W = nn.Parameter(torch.randn(output_dim, input_dim) * scale)
        self.U = nn.Parameter(torch.randn(output_dim, output_dim) * scale)
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.act = nn.Tanh()

    def forward(self, x, prev_state=None):
        """Apply the cell; a missing prev_state defaults to zeros on x's device."""
        if prev_state is None:
            prev_state = torch.zeros(x.size(0), self.W.size(0), device=x.device)
        pre_activation = x @ self.W.T + prev_state @ self.U.T + self.bias
        return self.act(pre_activation)
18
-
19
class LiquidNetwork(nn.Module):
    """Four stacked LiquidLayers whose outputs are concatenated then projected."""

    def __init__(self, in_dim=768, h_dim=4000, out_dim=768):
        super().__init__()
        self.l1 = LiquidLayer(in_dim, h_dim)
        self.l2 = LiquidLayer(h_dim, h_dim)
        self.l3 = LiquidLayer(h_dim, h_dim)
        self.l4 = LiquidLayer(h_dim, h_dim)
        self.l5 = nn.Linear(h_dim * 4, out_dim)  # projects the 4-way concat

    def forward(self, x):
        """Feed x through l1..l4 and project the concatenated hidden states."""
        h1 = self.l1(x)
        h2 = self.l2(h1)
        h3 = self.l3(h2)
        h4 = self.l4(h3)
        concat = torch.cat((h1, h2, h3, h4), dim=-1)
        return self.l5(concat)
34
-
35
- # --- Bottleneck Autoencoder ---
36
class BottleneckT5Autoencoder:
    """Text <-> latent round-trips via the contra-bottleneck T5 model."""

    def __init__(self, model_path='thesephist/contra-bottleneck-t5-base-wikipedia', device='cpu'):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # trust_remote_code: the checkpoint ships custom modeling code that
        # exposes the encode_only / perturb_vector hooks used below.
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
        self.model.eval()

    @torch.no_grad()
    def embed(self, text: str):
        """Encode text into the model's bottleneck latent vector."""
        encoder_inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
        decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device)
        outputs = self.model(
            **encoder_inputs,
            decoder_input_ids=decoder_inputs['input_ids'],
            encode_only=True,
        )
        return outputs[0].squeeze(0).detach()

    @torch.no_grad()
    def generate_from_latent(self, latent, max_length=512, temperature=1.0):
        """Sample text from a latent by perturbing a dummy sentence's embedding."""
        dummy_text = '.'
        base_latent = self.embed(dummy_text)
        # NOTE(review): presumably the remote-code model reads perturb_vector
        # during generate() — it must be set before the call; confirm upstream.
        self.model.perturb_vector = latent - base_latent
        input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids
        output = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            top_p=0.9,
            temperature=temperature,
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
68
-
69
- # --- Plug-and-play Pipeline ---
70
# --- Plug-and-play Pipeline ---
class Pipeline:
    """Prompt -> bottleneck latent -> LiquidNetwork -> generated text."""

    def __init__(self, model_name: str, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.autoencoder = BottleneckT5Autoencoder(device=self.device)
        self.model = LiquidNetwork().to(self.device)

        # Fetch the trained LiquidNetwork weights straight from the HF repo.
        weights_url = f"https://huggingface.co/{model_name}/resolve/main/model.pth"
        state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, prompt: str) -> str:
        """Generate text for a prompt via the latent-space model."""
        with torch.no_grad():
            embedding = self.autoencoder.embed(prompt).unsqueeze(0).to(self.device)
            latent = self.model(embedding)
        return self.autoencoder.generate_from_latent(latent.squeeze(0))
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # --- Liquid Network ---
6
class LiquidLayer(nn.Module):
    """Recurrent layer computing tanh(x @ W.T + prev_state @ U.T + bias)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # 0.02-scaled random init keeps the initial pre-activations small.
        self.W = nn.Parameter(torch.randn(output_dim, input_dim) * 0.02)
        self.U = nn.Parameter(torch.randn(output_dim, output_dim) * 0.02)
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.act = nn.Tanh()

    def forward(self, x, prev_state=None):
        """Run one step; a None prev_state is treated as an all-zero state."""
        if prev_state is None:
            batch, out_dim = x.size(0), self.W.size(0)
            prev_state = torch.zeros(batch, out_dim, device=x.device)
        pre = x @ self.W.T + prev_state @ self.U.T + self.bias
        return self.act(pre)
18
+
19
class LiquidNetwork(nn.Module):
    """Cascade of four LiquidLayers; the concatenation of all four hidden
    states is linearly projected back to out_dim."""

    def __init__(self, in_dim=768, h_dim=4000, out_dim=768):
        super().__init__()
        self.l1 = LiquidLayer(in_dim, h_dim)
        self.l2 = LiquidLayer(h_dim, h_dim)
        self.l3 = LiquidLayer(h_dim, h_dim)
        self.l4 = LiquidLayer(h_dim, h_dim)
        # Final projection sees all four hidden states side by side.
        self.l5 = nn.Linear(h_dim * 4, out_dim)

    def forward(self, x):
        """Run x through the cascade and project the concatenated states."""
        states = []
        hidden = x
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = layer(hidden)
            states.append(hidden)
        return self.l5(torch.cat(states, dim=-1))
34
+
35
+ # --- Bottleneck Autoencoder ---
36
class BottleneckT5Autoencoder:
    """Wrapper around the contra-bottleneck T5 checkpoint: embed() maps text
    to a latent vector, generate_from_latent() maps a latent back to text."""

    def __init__(self, model_path='thesephist/contra-bottleneck-t5-base-wikipedia', device='cpu'):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # The checkpoint ships custom modeling code (hence trust_remote_code)
        # providing the encode_only flag and perturb_vector attribute below.
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
        self.model.eval()

    @torch.no_grad()
    def embed(self, text: str):
        """Return the bottleneck latent for `text` as a 1-D detached tensor."""
        enc = self.tokenizer(text, return_tensors='pt').to(self.device)
        dec = self.tokenizer('', return_tensors='pt').to(self.device)
        result = self.model(
            **enc,
            decoder_input_ids=dec['input_ids'],
            encode_only=True,
        )
        return result[0].squeeze(0).detach()

    @torch.no_grad()
    def generate_from_latent(self, latent, max_length=512, temperature=1.0):
        """Sample text whose latent matches `latent`, via a perturbed dummy."""
        dummy_text = '.'
        # NOTE(review): the remote-code model appears to add perturb_vector to
        # the bottleneck inside generate() — set it before calling; confirm.
        self.model.perturb_vector = latent - self.embed(dummy_text)
        input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids
        generated = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            top_p=0.9,
            temperature=temperature,
        )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)
68
+
69
+
70
class Pipeline:
    """Plug-and-play pipeline: prompt -> bottleneck latent -> LiquidNetwork -> text.

    Loads the trained LiquidNetwork weights (model.pth) from the given
    Hugging Face repository `model_name`.
    """

    def __init__(self, model_name: str, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.autoencoder = BottleneckT5Autoencoder(device=self.device)
        self.model = LiquidNetwork().to(self.device)

        # BUG FIX: the previous code called hf_hub_download() without ever
        # importing it, so __init__ raised NameError. Download model.pth
        # directly from the repo's resolve URL via torch.hub instead (torch
        # caches the file locally), matching the earlier working revision.
        state_dict = torch.hub.load_state_dict_from_url(
            f"https://huggingface.co/{model_name}/resolve/main/model.pth",
            map_location=self.device,
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, prompt: str) -> str:
        """Generate text from `prompt` through the latent-space LiquidNetwork."""
        with torch.no_grad():
            embedding = self.autoencoder.embed(prompt).unsqueeze(0).to(self.device)
            pred = self.model(embedding)
        output = self.autoencoder.generate_from_latent(pred.squeeze(0))
        return output