import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
# --- Liquid Network ---
class LiquidLayer(nn.Module):
    """A recurrent-style cell: tanh(x @ W.T + prev_state @ U.T + bias)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(output_dim, input_dim) * 0.02)   # input weights
        self.U = nn.Parameter(torch.randn(output_dim, output_dim) * 0.02)  # recurrent weights
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.act = nn.Tanh()

    def forward(self, x, prev_state=None):
        # With no previous state, the recurrent term contributes zero.
        if prev_state is None:
            prev_state = torch.zeros(x.size(0), self.W.size(0), device=x.device)
        return self.act(x @ self.W.T + prev_state @ self.U.T + self.bias)

class LiquidNetwork(nn.Module):
    """Four stacked LiquidLayers; all four hidden states are concatenated
    and projected back down to the embedding dimension."""

    def __init__(self, in_dim=768, h_dim=4000, out_dim=768):
        super().__init__()
        self.l1 = LiquidLayer(in_dim, h_dim)
        self.l2 = LiquidLayer(h_dim, h_dim)
        self.l3 = LiquidLayer(h_dim, h_dim)
        self.l4 = LiquidLayer(h_dim, h_dim)
        self.l5 = nn.Linear(h_dim * 4, out_dim)

    def forward(self, x):
        h1 = self.l1(x)
        h2 = self.l2(h1)
        h3 = self.l3(h2)
        h4 = self.l4(h3)
        return self.l5(torch.cat([h1, h2, h3, h4], dim=-1))
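
# Shape sanity check (sketch; illustrative dims, not the trained config).
# Since the four h_dim-sized hidden states are concatenated before the final
# Linear, the output recovers out_dim:
#
#     net = LiquidNetwork(in_dim=16, h_dim=8, out_dim=16)
#     x = torch.randn(2, 16)
#     net(x).shape  # torch.Size([2, 16])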

# --- Bottleneck Autoencoder ---
class BottleneckT5Autoencoder:
    """Wraps thesephist's contra-bottleneck-t5 model, which maps text to a
    single latent vector and back. Its custom forward/generate logic is
    loaded via trust_remote_code."""

    def __init__(self, model_path='thesephist/contra-bottleneck-t5-base-wikipedia', device='cpu'):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
        self.model.eval()

    @torch.no_grad()
    def embed(self, text: str):
        # encode_only is handled by the model's custom remote code and
        # returns the bottleneck latent for the input text.
        inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
        decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device)
        return self.model(
            **inputs,
            decoder_input_ids=decoder_inputs['input_ids'],
            encode_only=True
        )[0].squeeze(0).detach()

    @torch.no_grad()
    def generate_from_latent(self, latent, max_length=512, temperature=1.0):
        # The model's custom code adds perturb_vector to the latent of the
        # dummy input during generation, so the decoder effectively generates
        # from `latent` itself.
        dummy_text = '.'
        dummy = self.embed(dummy_text)
        perturb_vector = latent - dummy
        self.model.perturb_vector = perturb_vector
        input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids
        output = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            top_p=0.9,
            temperature=temperature
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


class Pipeline:
    """Embed a prompt, map its latent through the LiquidNetwork, and decode
    the predicted latent back to text."""

    def __init__(self, model_name: str, device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.autoencoder = BottleneckT5Autoencoder(device=self.device)
        self.model = LiquidNetwork().to(self.device)

        # Download the trained LiquidNetwork weights (model.pth) from the Hugging Face Hub
        model_path = hf_hub_download(repo_id=model_name, filename="model.pth")

        # Load the state dict onto the target device
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, prompt: str) -> str:
        # prompt -> latent -> LiquidNetwork prediction -> decoded text
        embedding = self.autoencoder.embed(prompt).unsqueeze(0).to(self.device)
        pred = self.model(embedding)
        return self.autoencoder.generate_from_latent(pred.squeeze(0))
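
# Example usage (sketch; the repo id below is a placeholder for a Hub repo
# that contains a model.pth matching LiquidNetwork's default dimensions):
#
#     pipe = Pipeline("your-username/liquid-net")
#     print(pipe("A sentence to transform."))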