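"""Gradio inference app for a custom seq2seq Transformer.

Downloads a trained checkpoint and tokenizer from the Hugging Face Hub and
serves greedy-decoded answers through a simple web UI.
"""
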
import os
import torch
import gradio as gr
import json
import logging
from huggingface_hub import hf_hub_download

# Configuration constants
MODEL_REPO = "Gajendra5490/Scrached_Trained_Model"
CURRENT_USER = "gajendra82"
CURRENT_UTC = "2025-05-06 16:00:41"

def setup_logging():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

logger = setup_logging()

class PositionalEncoding(torch.nn.Module):
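    """Standard sinusoidal positional encoding (Vaswani et al., 2017).

    The buffer is stored time-first, shape (max_len, 1, d_model), to match the
    layout of the saved checkpoint, so forward() expects time-first input.
    """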
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, 1, d_model)  # Changed dimension order to match saved model
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class ImprovedTransformer(torch.nn.Module):
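    """Batch-first encoder-decoder Transformer with a shared source/target embedding.

    max_seq_length is accepted for checkpoint compatibility but is not used
    directly here; usable length is bounded by the positional encoding buffer.
    """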
    def __init__(
        self,
        vocab_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        max_seq_length=128
    ):
        super().__init__()
        
        self.d_model = d_model
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        # Main transformer
        self.transformer = torch.nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        
        # Output layer
        self.output_layer = torch.nn.Linear(d_model, vocab_size)
        self.norm = torch.nn.LayerNorm(d_model)
        
    def forward(self, src, tgt):
        # Create padding masks
        src_key_padding_mask = (src == 0).to(src.device)
        tgt_key_padding_mask = (tgt == 0).to(tgt.device)
        
        # Create causal mask for target
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        
        # Embeddings and positional encoding
        src = self.embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float))
        tgt = self.embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float))
        
        src = src.transpose(0, 1)  # Change to time-first
        tgt = tgt.transpose(0, 1)  # Change to time-first
        
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        
        src = src.transpose(0, 1)  # Back to batch-first
        tgt = tgt.transpose(0, 1)  # Back to batch-first
        
        # Transform
        output = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask  # also mask padded source positions in cross-attention
        )
        
        # Output processing
        output = self.norm(output)
        return self.output_layer(output)

class ModelInference:
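    """Downloads the checkpoint from the Hub once and serves greedy-decoded answers."""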
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")
        self.load_model()

    def load_model(self):
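        """Fetch model.pt and tokenizer.json from the Hub and rebuild the network."""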
        try:
            token = os.environ.get('HF_TOKEN')
            if not token:
                raise ValueError("HF_TOKEN not found in environment variables")
            
            # Download files
            self.logger.info(f"Downloading from {MODEL_REPO}")
            model_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename="model.pt",
                token=token
            )
            
            tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename="tokenizer.json",
                token=token
            )
            
            # Load model data first
            self.logger.info("Loading model data...")
            model_data = torch.load(
                model_path,
                map_location=self.device
            )
            
            # Load tokenizer
            self.logger.info("Loading tokenizer...")
            with open(tokenizer_path, 'r', encoding='utf-8') as f:
                tokenizer_data = json.load(f)
            
            # Get exact vocabulary size from the saved model
            self.vocab = tokenizer_data['vocab']
            vocab_size = 1747  # Exact size from the saved model
            
            # Initialize special tokens to match the saved model
            self.special_tokens = {
                "<user>": vocab_size - 4,
                "<assistant>": vocab_size - 3,
                "<sep>": vocab_size - 2,
                "<eos>": vocab_size - 1
            }
            
            # Initialize model with exact vocab size from saved model
            self.model = ImprovedTransformer(
                vocab_size=vocab_size,  # Use exact size
                d_model=512,
                nhead=8,
                num_encoder_layers=3,
                num_decoder_layers=3,
                dim_feedforward=2048
            ).to(self.device)
            
            # Load state dict
            self.model.load_state_dict(model_data['model_state_dict'])
            self.model.eval()
            
            self.logger.info("Model loaded successfully")
            
        except Exception as e:
            self.logger.error(f"Error loading model: {str(e)}")
            raise

    def encode(self, text):
        # Whitespace tokenization: special tokens map to their reserved ids,
        # anything out of vocabulary falls back to id 0
        tokens = text.split()
        return [self.special_tokens[token] if token in self.special_tokens
                else self.vocab.get(token, 0) for token in tokens]

    def decode(self, ids):
        # Invert the vocab maps on each call; cheap enough for interactive use
        reverse_vocab = {v: k for k, v in self.vocab.items()}
        reverse_special = {v: k for k, v in self.special_tokens.items()}
        return " ".join(reverse_vocab.get(i, reverse_special.get(i, "<unk>"))
                        for i in ids)

    @torch.no_grad()
    def generate_answer(self, input_text: str) -> str:
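        """Greedy decoding: extend the response one argmax token at a time
        until <eos> is produced or 150 tokens have been generated."""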
        try:
            input_text = input_text.strip()
            self.logger.info(f"Processing input: {input_text}")
            
            # Tokenize
            input_ids = self.encode(f"<user> {input_text} <sep>")
            input_tensor = torch.tensor([input_ids]).to(self.device)
            
            # Initialize response
            response_ids = [self.special_tokens["<assistant>"]]
            response_tensor = torch.tensor([response_ids]).to(self.device)
            
            # Generate greedily, one argmax token at a time (up to 150 tokens)
            for _ in range(150):
                output = self.model(input_tensor, response_tensor)
                next_token = output[0, -1].argmax().item()
                
                if next_token == self.special_tokens["<eos>"]:
                    break
                    
                response_ids.append(next_token)
                response_tensor = torch.tensor([response_ids]).to(self.device)
            
            # Decode
            answer = self.decode(response_ids)
            answer = answer.replace("<assistant>", "").replace("<eos>", "").strip()
            
            self.logger.info(f"Generated response: {answer}")
            return answer
            
        except Exception as e:
            self.logger.error(f"Error generating answer: {str(e)}")
            return f"Error generating answer: {str(e)}"

# Global model handle; constructed lazily so the UI can start before the checkpoint downloads
model = None

def process_input(input_text):
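    """Build the model on first use, then generate an answer for the input."""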
    global model
    try:
        if model is None:
            model = ModelInference()
        return model.generate_answer(input_text)
    except Exception as e:
        logger.error(f"Error processing input: {str(e)}")
        return f"Error: {str(e)}"

# Create Gradio interface
interface = gr.Interface(
    fn=process_input,
    inputs=gr.Textbox(
        label="Input Question",
        placeholder="Enter your question here...",
        lines=2
    ),
    outputs=gr.Textbox(
        label="Model Response",
        lines=4
    ),
    title="Model Inference Interface",
    description=f"""
    Model Repository: {MODEL_REPO}
    Current User: {CURRENT_USER}
    Last Updated: {CURRENT_UTC} UTC
    
    Enter your question and click submit to get a response.
    """,
    theme=gr.themes.Soft(),
    examples=[
        ["What is this about?"],
        ["Can you explain the topic?"],
        ["Give me more details."]
    ]
)

# Launch
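# (Defaults are fine on Hugging Face Spaces; on other hosts you would
# typically pass server_name="0.0.0.0" so the app is reachable externally.)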
interface.launch()