AlekMan commited on
Commit
3648389
·
verified ·
1 Parent(s): f62d6f9

Upload 4 files

Browse files
Files changed (3) hide show
  1. app.py +301 -0
  2. requirements.txt +0 -0
  3. research/xlstm_config.yaml +20 -0
app.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, Tuple
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ import logging
5
+
6
+ import gradio as gr
7
+ from omegaconf import OmegaConf
8
+ from dacite import Config as DaciteConfig, from_dict
9
+ from transformers import GPT2Config, GPT2LMHeadModel
10
+
11
+ from llm_trainer import LLMTrainer
12
+ from xlstm import xLSTMLMModel, xLSTMLMModelConfig
13
+
14
# Configure root logging once at import time; INFO level so model-loading
# progress messages below are visible when the app starts.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
16
+
17
+
18
@dataclass
class ModelConfig:
    """Static description of one servable model checkpoint."""

    # Which architecture this entry refers to.
    name: Literal["xLSTM", "GPT2"]
    # Filesystem path to the trained weights (.pth), relative to the app cwd.
    checkpoint_path: str
    # Path to a YAML architecture config; only the xLSTM entry uses this,
    # GPT2 is configured via GPT2_CONFIG below.
    config_path: Optional[str] = None
23
+
24
+
25
# Registry of servable models, keyed by the name shown in the UI dropdown.
# Checkpoint paths are relative to the application working directory.
MODEL_CONFIGS = {
    "GPT2": ModelConfig(
        name="GPT2",
        checkpoint_path="checkpoints/gpt/cp_3999.pth"
    ),
    "xLSTM": ModelConfig(
        name="xLSTM",
        # Bug fix: was "checpoints/..." (typo), which can never exist on disk
        # and made the xLSTM checkpoint unloadable; the GPT2 entry already
        # uses the "checkpoints/" prefix.
        checkpoint_path="checkpoints/xlstm/cp_9999.pth",
        config_path="research/xlstm_config.yaml"
    )
}
36
+
37
# Architecture hyper-parameters for the GPT2 option. These presumably must
# match the training run that produced MODEL_CONFIGS["GPT2"]'s checkpoint,
# or state_dict loading will fail — TODO confirm against the training config.
GPT2_CONFIG = GPT2Config(
    vocab_size=50304,  # NOTE(review): looks padded beyond GPT-2's 50257 — confirm
    n_positions=256,   # maximum context length
    n_embd=768,
    n_layer=12,
    n_head=12,
    activation_function="gelu"
)
45
+
46
+ UI_CONFIG = {
47
+ "title": "HSEAI",
48
+ "description": "Enter your text below and the AI will continue it.",
49
+ "port": 7860,
50
+ "host": "0.0.0.0",
51
+ "default_model": "xLSTM",
52
+ "max_sequences": 3,
53
+ "default_length": 64,
54
+ "min_length": 16,
55
+ "max_length": 128,
56
+ "length_step": 16
57
+ }
58
+
59
+
60
class ModelManager:
    """Lazily loads models and caches exactly one trainer at a time.

    Keeping a single trainer bounds memory usage; switching the model in the
    UI replaces the cached instance.
    """

    def __init__(self):
        # The currently loaded trainer, or None before first use.
        self._current_trainer: Optional[LLMTrainer] = None
        # Name of the model the cached trainer belongs to.
        self._current_model: Optional[str] = None

    def _create_gpt2_trainer(self) -> LLMTrainer:
        """Build a GPT2 model (weights restored later via load_checkpoint)."""
        model = GPT2LMHeadModel(GPT2_CONFIG)
        return LLMTrainer(model=model, model_returns_logits=False)

    def _create_xlstm_trainer(self, config_path: str) -> LLMTrainer:
        """Build an xLSTM model from its YAML config.

        Raises:
            FileNotFoundError: if the YAML config file does not exist.
        """
        if not Path(config_path).exists():
            raise FileNotFoundError(f"xLSTM config file not found: {config_path}")

        cfg = OmegaConf.load(config_path)
        # strict=True makes unknown YAML keys fail loudly instead of being
        # silently dropped.
        cfg = from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(cfg),
            config=DaciteConfig(strict=True)
        )
        model = xLSTMLMModel(cfg)
        return LLMTrainer(model=model, model_returns_logits=True)

    def get_trainer(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
        """Return a cached trainer for *model_name*, loading on first use or switch."""
        if self._current_trainer is None or self._current_model != model_name:
            logger.info("Loading model: %s", model_name)
            self._current_trainer = self._load_model(model_name)
            self._current_model = model_name
            logger.info("Model %s loaded successfully", model_name)

        return self._current_trainer

    def _load_model(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
        """Instantiate the architecture and restore its checkpoint.

        Raises:
            ValueError: if *model_name* is not a known model.
            RuntimeError: on any failure during construction or checkpoint
                loading; the original exception is preserved via ``__cause__``.
        """
        if model_name not in MODEL_CONFIGS:
            raise ValueError(f"Invalid model: {model_name}. Valid models: {list(MODEL_CONFIGS.keys())}")

        config = MODEL_CONFIGS[model_name]

        try:
            if model_name == "GPT2":
                trainer = self._create_gpt2_trainer()
            elif model_name == "xLSTM":
                # Guard the Optional field so a misconfigured registry entry
                # fails with a clear message instead of a TypeError deeper in.
                if config.config_path is None:
                    raise ValueError(f"Missing config_path for model: {model_name}")
                trainer = self._create_xlstm_trainer(config.config_path)
            else:
                raise ValueError(f"Unsupported model: {model_name}")

            checkpoint_path = Path(config.checkpoint_path)
            if not checkpoint_path.exists():
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

            logger.info("Loading checkpoint: %s", checkpoint_path)
            trainer.load_checkpoint(str(checkpoint_path))
            return trainer

        except Exception as e:
            # logger.exception records the full traceback, and "from e" chains
            # the cause so callers see the root failure, not just the wrapper.
            logger.exception("Failed to load model %s", model_name)
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e


# Module-level singleton shared by all Gradio callbacks.
model_manager = ModelManager()
125
+
126
+
127
def generate_text(
    user_input: str,
    model_choice: str = UI_CONFIG["default_model"],
    n_sequences: int = UI_CONFIG["max_sequences"],
    length: int = UI_CONFIG["default_length"]
) -> Tuple[str, str, str]:
    """Produce up to three continuations of *user_input* with the chosen model.

    Returns a 3-tuple of display strings; on error the first slot carries the
    error message and the rest are empty.
    """
    if not user_input.strip():
        return "Please enter some text first.", "", ""

    try:
        logger.info(f"Generating text with {model_choice}, length: {length}")
        trainer = model_manager.get_trainer(model_choice)

        generated = trainer.generate_text(
            prompt=user_input,
            n_return_sequences=n_sequences,
            length=length
        )

        # The model output echoes the prompt; keep only the generated tail.
        display = []
        for sequence in generated[:n_sequences]:
            tail = sequence[len(user_input):].strip()
            display.append(tail + "..." if tail else "(No continuation generated)")

        # Pad so there is always one string per output textbox.
        display.extend([""] * (3 - len(display)))

        logger.info("Text generation completed successfully")
        return display[0], display[1], display[2]

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", ""
167
+
168
+
169
def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Button]:
    """Build the prompt textbox, model/length controls, and generate button."""
    with gr.Column():
        prompt_box = gr.Textbox(
            label="Enter your text:",
            placeholder="Type your text here...",
            lines=3,
            max_lines=10,
        )

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                value=UI_CONFIG["default_model"],
                label="Model",
                interactive=True,
            )
            length_slider = gr.Slider(
                minimum=UI_CONFIG["min_length"],
                maximum=UI_CONFIG["max_length"],
                value=UI_CONFIG["default_length"],
                step=UI_CONFIG["length_step"],
                label="Generation Length",
            )

        run_button = gr.Button("Generate Continuation", variant="primary")

    return prompt_box, model_dropdown, length_slider, run_button
198
+
199
+
200
def create_output_section() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox]:
    """Build the three read-only textboxes that display generated continuations."""
    gr.Markdown("### Generated Continuations:")

    with gr.Row():
        # Three identical boxes, labelled "Continuation 1".."Continuation 3".
        boxes = [
            gr.Textbox(
                label=f"Continuation {idx}",
                lines=8,
                max_lines=15,
                interactive=False,
            )
            for idx in (1, 2, 3)
        ]

    return boxes[0], boxes[1], boxes[2]
225
+
226
+
227
def setup_event_handlers(
    user_input: gr.Textbox,
    model_choice: gr.Dropdown,
    length: gr.Slider,
    generate_btn: gr.Button,
    outputs: Tuple[gr.Textbox, gr.Textbox, gr.Textbox]
) -> None:
    """Wire the button click and textbox submit to the generation callback."""
    # Hidden constant supplies n_sequences so generate_text's signature maps
    # one-to-one onto the visible controls plus this fixed value.
    sequence_count = gr.Number(value=UI_CONFIG["max_sequences"], visible=False)
    handler_inputs = [user_input, model_choice, sequence_count, length]

    # Both triggers (button press and Enter in the textbox) run the same handler.
    for register in (generate_btn.click, user_input.submit):
        register(
            fn=generate_text,
            inputs=handler_inputs,
            outputs=list(outputs)
        )
253
+
254
+
255
def create_interface() -> gr.Blocks:
    """Assemble the full Gradio Blocks app from the input and output sections."""
    with gr.Blocks(title=UI_CONFIG["title"], theme=gr.themes.Soft()) as demo:
        # Header: app name and a one-line usage hint.
        gr.Markdown(f"# {UI_CONFIG['title']}")
        gr.Markdown(UI_CONFIG["description"])

        with gr.Row():
            prompt_box, model_dropdown, length_slider, run_button = create_input_section()

        result_boxes = create_output_section()

        setup_event_handlers(prompt_box, model_dropdown, length_slider, run_button, result_boxes)

    return demo
270
+
271
+
272
def initialize_model_on_startup() -> None:
    """Warm up the default model so the first request is not slowed by loading.

    Failures are logged but not raised — the app still starts and the model
    is loaded lazily on the first generation request instead.
    """
    default = UI_CONFIG["default_model"]
    try:
        logger.info(f"Initializing {default} model on startup...")
        model_manager.get_trainer(default)
        logger.info(f"{default} model initialized successfully!")
    except Exception as exc:
        logger.warning(f"Could not initialize model on startup: {exc}")
        logger.info("Model will be initialized when first used.")
281
+
282
+
283
def main() -> None:
    """Entry point: warm up the default model, build the UI, and serve it."""
    logger.info(f"Starting {UI_CONFIG['title']} application...")

    initialize_model_on_startup()
    demo = create_interface()

    logger.info(f"Launching interface on {UI_CONFIG['host']}:{UI_CONFIG['port']}")
    # share=False keeps the app local; show_error surfaces callback
    # exceptions in the browser instead of a generic failure message.
    demo.launch(
        server_name=UI_CONFIG["host"],
        server_port=UI_CONFIG["port"],
        share=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()
requirements.txt ADDED
File without changes
research/xlstm_config.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# xLSTM language-model configuration. Loaded by app.py via OmegaConf and
# converted with dacite(strict=True), so keys must exactly match
# xLSTMLMModelConfig fields.
vocab_size: 50304        # matches the GPT2 config in app.py
tie_weights: True        # presumably ties input embedding and output head — confirm in xlstm docs
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 4
slstm_block:
  slstm:
    # backend: cuda
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256      # matches n_positions=256 in app.py's GPT2 config
num_blocks: 24
embedding_dim: 768
slstm_at: [3, 20]        # NOTE(review): looks like block indices that use sLSTM — confirm