File size: 8,603 Bytes
b227d2e
 
3648389
 
 
 
 
 
 
b227d2e
3648389
 
 
 
b227d2e
 
3648389
 
 
 
451e175
 
3648389
 
 
 
 
 
 
 
 
 
 
451e175
 
 
3648389
 
 
 
 
 
 
 
 
 
 
 
 
 
86a7ce8
 
451e175
 
 
 
 
3648389
 
 
 
 
 
 
451e175
3648389
 
 
 
 
 
 
 
 
 
 
 
 
451e175
3648389
451e175
 
 
 
 
3648389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451e175
3648389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import os

from typing import Literal, Optional, Tuple
import logging

import gradio as gr
from omegaconf import OmegaConf
from dacite import Config as DaciteConfig, from_dict
from transformers import GPT2Config, GPT2LMHeadModel
from huggingface_hub import PyTorchModelHubMixin, login

from llm_trainer import LLMTrainer
from xlstm import xLSTMLMModel, xLSTMLMModelConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authenticate against the Hugging Face Hub with the token stored in the
# `token` environment variable. Calling login(token=None) falls back to an
# interactive prompt, which cannot be answered on a headless server, so the
# login is skipped (with a warning) when the variable is missing or empty.
_hf_token = os.getenv('token')
if _hf_token:
    login(token=_hf_token)
else:
    logger.warning("Environment variable 'token' is not set; skipping Hugging Face login.")


class xLSTMWrapper(xLSTMLMModel, PyTorchModelHubMixin):
    """xLSTM language model combined with ``PyTorchModelHubMixin``.

    The class adds no behaviour of its own; it only mixes the Hub helper
    into ``xLSTMLMModel`` so that ``from_pretrained`` is available on it.
    """


# Hyperparameters describing the GPT-2 architecture.
GPT2_CONFIG = GPT2Config(
    vocab_size=50304,
    n_positions=256,
    n_embd=768,
    n_layer=12,
    n_head=12,
    activation_function="gelu",
)

# xLSTM configuration: read the YAML file, convert it to plain Python
# containers and build the config dataclass from them in strict mode.
XLSTM_CONFIG = from_dict(
    data_class=xLSTMLMModelConfig,
    data=OmegaConf.to_container(OmegaConf.load("xlstm_config.yaml")),
    config=DaciteConfig(strict=True),
)

# Settings for the Gradio front end: page texts, server address and the
# bounds/defaults of the generation controls.
UI_CONFIG = dict(
    title="HSEAI",
    description="Enter your text below and the AI will continue it.",
    port=7860,
    host="0.0.0.0",
    default_model="xLSTM",
    max_sequences=3,
    default_length=64,
    min_length=16,
    max_length=128,
    length_step=16,
)


# Load all four selectable models eagerly at import time.
#
# `from_pretrained` is a classmethod on both model classes, so it is called
# on the class itself. Calling it on a freshly constructed instance (as in
# `Cls(config).from_pretrained(...)`) first builds a complete randomly
# initialised model that is thrown away immediately, doubling construction
# time and peak memory. The loaded GPT-2 models take their configuration from
# the Hub repository, exactly as they did before.
xLSTM = xLSTMWrapper.from_pretrained("AlekMan/HSE_AI_Vanilla_XLSTM", config=XLSTM_CONFIG)
xLSTM_ft = xLSTMWrapper.from_pretrained("AlekMan/HSE_AI_Vanilla_XLSTM_FT", config=XLSTM_CONFIG)
gpt2 = GPT2LMHeadModel.from_pretrained("AlekMan/HSE_AI_GPT2")
# The fine-tuned GPT-2 variant is the same base checkpoint plus a LoRA adapter.
gpt2_lora = GPT2LMHeadModel.from_pretrained("AlekMan/HSE_AI_GPT2")
gpt2_lora.load_adapter("AlekMan/HSE_AI_GPT2_LoRA")


class ModelManager:
    """Hand out ``LLMTrainer`` wrappers for the preloaded models.

    Only the most recently requested trainer is cached: asking for a
    different model name builds a new trainer and replaces the cached one.
    """

    def __init__(self):
        self._current_trainer: Optional[LLMTrainer] = None
        self._current_model: Optional[str] = None

    def get_trainer(
        self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]
    ) -> LLMTrainer:
        """Return the trainer for ``model_name``, building it if it is not cached.

        Args:
            model_name: Name of the model to serve.

        Returns:
            The cached trainer when ``model_name`` matches the previous
            request, otherwise a newly created one.

        Raises:
            RuntimeError: If the trainer cannot be created.
        """
        if self._current_trainer is None or self._current_model != model_name:
            logger.info(f"Loading model: {model_name}")
            self._current_trainer = self._load_model(model_name)
            self._current_model = model_name
            logger.info(f"Model {model_name} loaded successfully")

        return self._current_trainer

    def _load_model(
        self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]
    ) -> LLMTrainer:
        """Wrap the preloaded model called ``model_name`` in an ``LLMTrainer``.

        Args:
            model_name: Name of the model to wrap.

        Returns:
            A new ``LLMTrainer`` around the matching module-level model.

        Raises:
            RuntimeError: For an unsupported name or any failure while
                building the trainer; the original error is chained as the
                cause.
        """
        # Model name -> (preloaded model, value passed as model_returns_logits).
        model_specs = {
            "GPT2": (gpt2, False),
            "xLSTM": (xLSTM, True),
            "GPT2_FT": (gpt2_lora, False),
            "xLSTM_FT": (xLSTM_ft, True),
        }

        try:
            if model_name not in model_specs:
                raise ValueError(f"Unsupported model: {model_name}")

            model, returns_logits = model_specs[model_name]
            return LLMTrainer(model=model, model_returns_logits=returns_logits)

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e


# Module-level manager instance shared by the text-generation handler and the
# startup warm-up.
model_manager = ModelManager()


def generate_text(
    user_input: str, 
    model_choice: str = UI_CONFIG["default_model"], 
    n_sequences: int = UI_CONFIG["max_sequences"], 
    length: int = UI_CONFIG["default_length"]
) -> Tuple[str, str, str]:
    """Generate up to three continuations of ``user_input`` with the chosen model.

    Args:
        user_input: Prompt text to continue.
        model_choice: Name of the model, passed to ``model_manager.get_trainer``.
        n_sequences: Number of continuations to request from the trainer.
        length: Generation length forwarded to the trainer.

    Returns:
        Exactly three strings, one per output box. Unused slots are empty
        strings. For blank input or a failure during generation, the first
        slot carries the message and the other two are empty.
    """
    # Treat a missing value (None) the same as blank text instead of crashing
    # on `.strip()`.
    if not user_input or not user_input.strip():
        return "Please enter some text first.", "", ""

    try:
        # Values coming from UI widgets may arrive as floats; slicing below
        # requires an integer, so normalise both numeric inputs up front.
        n_sequences = int(n_sequences)
        length = int(length)

        logger.info(f"Generating text with {model_choice}, length: {length}")

        trainer = model_manager.get_trainer(model_choice)

        continuations = trainer.generate_text(
            prompt=user_input,
            n_return_sequences=n_sequences,
            length=length
        )

        results = []
        for continuation in continuations[:n_sequences]:
            # NOTE(review): assumes each continuation starts with the prompt
            # verbatim, so dropping len(user_input) characters leaves only the
            # generated part — confirm against LLMTrainer.generate_text.
            clean_continuation = continuation[len(user_input):].strip()
            if clean_continuation:
                results.append(clean_continuation + "...")
            else:
                results.append("(No continuation generated)")

        # The function always returns three values; pad missing ones.
        results.extend([""] * (3 - len(results)))

        logger.info("Text generation completed successfully")
        return results[0], results[1], results[2]

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(error_msg)
        return error_msg, "", ""


def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Button]:
    """Build the prompt box, the model and length controls, and the button.

    Returns:
        The textbox, model dropdown, length slider and generate button, in
        that order.
    """
    model_names = ["GPT2", "GPT2_FT", "xLSTM", "xLSTM_FT"]
    slider_settings = dict(
        minimum=UI_CONFIG["min_length"],
        maximum=UI_CONFIG["max_length"],
        value=UI_CONFIG["default_length"],
        step=UI_CONFIG["length_step"],
        label="Generation Length",
    )

    with gr.Column():
        prompt_box = gr.Textbox(
            label="Enter your text:",
            placeholder="Type your text here...",
            lines=3,
            max_lines=10,
        )

        # Model selector and length slider share one row.
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=model_names,
                value=UI_CONFIG["default_model"],
                label="Model",
                interactive=True,
            )
            length_slider = gr.Slider(**slider_settings)

        submit_button = gr.Button("Generate Continuation", variant="primary")

    return prompt_box, model_dropdown, length_slider, submit_button


def create_output_section() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox]:
    """Build the heading and the row of three read-only result boxes.

    Returns:
        The three output textboxes, labelled "Continuation 1" to
        "Continuation 3", in that order.
    """
    gr.Markdown("### Generated Continuations:")

    with gr.Row():
        # The boxes differ only in their label, so create them in a loop
        # while the row context is still open.
        boxes = [
            gr.Textbox(
                label=f"Continuation {number}",
                lines=8,
                max_lines=15,
                interactive=False,
            )
            for number in (1, 2, 3)
        ]

    first, second, third = boxes
    return first, second, third


def setup_event_handlers(
    user_input: gr.Textbox,
    model_choice: gr.Dropdown, 
    length: gr.Slider,
    generate_btn: gr.Button,
    outputs: Tuple[gr.Textbox, gr.Textbox, gr.Textbox]
) -> None:
    """Connect the button click and the textbox submit event to ``generate_text``.

    Args:
        user_input: Prompt textbox; its submit event also triggers generation.
        model_choice: Dropdown holding the model name.
        length: Slider holding the generation length.
        generate_btn: Button whose click triggers generation.
        outputs: The three textboxes that receive the results.
    """
    # The sequence count is not user-editable; an invisible Number component
    # supplies it as the third argument of generate_text.
    sequence_count = gr.Number(value=UI_CONFIG["max_sequences"], visible=False)
    inputs = [user_input, model_choice, sequence_count, length]

    # Both triggers use the same handler, inputs and outputs.
    for register in (generate_btn.click, user_input.submit):
        register(
            fn=generate_text,
            inputs=inputs,
            outputs=list(outputs),
        )


def create_interface() -> gr.Blocks:
    """Assemble the complete Gradio page and return it.

    Returns:
        The ``gr.Blocks`` app containing the title, description, input
        controls, output boxes and their event wiring.
    """
    page_title = UI_CONFIG["title"]

    with gr.Blocks(title=page_title, theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {page_title}")
        gr.Markdown(UI_CONFIG["description"])

        with gr.Row():
            prompt_box, model_dropdown, length_slider, submit_button = create_input_section()

        result_boxes = create_output_section()

        setup_event_handlers(
            prompt_box,
            model_dropdown,
            length_slider,
            submit_button,
            result_boxes,
        )

    return demo


def initialize_model_on_startup() -> None:
    """Warm up the default model's trainer before the UI starts.

    Failures are logged as a warning and otherwise ignored, so the trainer
    is then created on the first generation request instead.
    """
    default_name = UI_CONFIG["default_model"]
    try:
        logger.info(f"Initializing {default_name} model on startup...")
        model_manager.get_trainer(default_name)
    except Exception as e:
        logger.warning(f"Could not initialize model on startup: {e}")
        logger.info("Model will be initialized when first used.")
    else:
        logger.info(f"{default_name} model initialized successfully!")


def main() -> None:
    """Warm up the default model, build the interface and serve it."""
    app_title = UI_CONFIG["title"]
    logger.info(f"Starting {app_title} application...")

    initialize_model_on_startup()
    demo = create_interface()

    # Host and port come from UI_CONFIG and are used for both the log line
    # and the launch call.
    host = UI_CONFIG["host"]
    port = UI_CONFIG["port"]
    logger.info(f"Launching interface on {host}:{port}")

    launch_options = dict(
        server_name=host,
        server_port=port,
        share=False,
        show_error=True,
    )
    demo.launch(**launch_options)


# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()