| from typing import Literal, Optional, Tuple |
| from dataclasses import dataclass |
| from pathlib import Path |
| import logging |
|
|
| import gradio as gr |
| from omegaconf import OmegaConf |
| from dacite import Config as DaciteConfig, from_dict |
| from transformers import GPT2Config, GPT2LMHeadModel |
|
|
| from llm_trainer import LLMTrainer |
| from xlstm import xLSTMLMModel, xLSTMLMModelConfig |
|
|
# Configure the root logger at INFO so the progress messages logged below
# (model loading, generation) are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass()
class ModelConfig:
    """Where to find the files for one selectable model."""

    # Model family; selects which trainer factory builds the model.
    name: Literal["xLSTM", "GPT2"]
    # Path of the checkpoint file to load into the freshly built model.
    checkpoint_path: str
    # Optional path to an architecture config file (used by xLSTM).
    config_path: Optional[str] = None
|
|
|
|
# Registry of selectable models, keyed by model name.
# Paths are relative, so they resolve against the process's working directory.
MODEL_CONFIGS = {
    "GPT2": ModelConfig(
        name="GPT2",
        checkpoint_path="checkpoints/gpt/cp_3999.pth",
    ),
    "xLSTM": ModelConfig(
        name="xLSTM",
        # Same top-level "checkpoints/" directory as the GPT2 entry.
        checkpoint_path="checkpoints/xlstm/cp_9999.pth",
        config_path="research/xlstm_config.yaml",
    ),
}
|
|
# Architecture of the GPT-2 model used for the "GPT2" choice.
# NOTE(review): presumably must match the shapes stored in the GPT2
# checkpoint, since weights are loaded into a model built from this — confirm.
GPT2_CONFIG = GPT2Config(
    # Depth / width
    n_layer=12,
    n_head=12,
    n_embd=768,
    # Maximum sequence length the position embeddings cover.
    n_positions=256,
    # NOTE(review): looks like GPT-2's 50257-token vocab padded to a multiple of 64 — verify.
    vocab_size=50304,
    activation_function="gelu",
)
|
|
# Static settings for the Gradio front end.
UI_CONFIG = dict(
    # Branding
    title="HSEAI",
    description="Enter your text below and the AI will continue it.",
    # Server binding
    port=7860,
    host="0.0.0.0",
    # Generation defaults
    default_model="xLSTM",
    max_sequences=3,
    # Bounds and step of the generation-length slider
    default_length=64,
    min_length=16,
    max_length=128,
    length_step=16,
)
|
|
|
|
class ModelManager:
    """Build, load and cache one ``LLMTrainer`` at a time.

    The most recently requested model is kept in memory. Asking for a
    different model name replaces the cached trainer once the new one has
    loaded. If loading fails, the previously cached trainer is kept.
    """

    def __init__(self):
        # Cached trainer and the model name it belongs to; both stay None
        # until the first successful load.
        self._current_trainer: Optional[LLMTrainer] = None
        self._current_model: Optional[str] = None

    def _create_gpt2_trainer(self) -> LLMTrainer:
        """Create a GPT-2 model from ``GPT2_CONFIG`` and wrap it in a trainer.

        Weights are freshly initialized here; ``_load_model`` loads the
        checkpoint afterwards.
        """
        model = GPT2LMHeadModel(GPT2_CONFIG)
        # NOTE(review): presumably False because the HF model returns an output
        # object rather than a bare logits tensor — confirm against LLMTrainer.
        return LLMTrainer(model=model, model_returns_logits=False)

    def _create_xlstm_trainer(self, config_path: Optional[str]) -> LLMTrainer:
        """Create an xLSTM model from a YAML config and wrap it in a trainer.

        Args:
            config_path: Path to a config file whose contents are converted
                into an ``xLSTMLMModelConfig``.

        Raises:
            ValueError: If ``config_path`` is None.
            FileNotFoundError: If the config file does not exist.
        """
        if config_path is None:
            # ModelConfig.config_path is Optional; give a clear message
            # instead of a TypeError from Path(None).
            raise ValueError("xLSTM model requires a config_path")
        if not Path(config_path).exists():
            raise FileNotFoundError(f"xLSTM config file not found: {config_path}")

        raw_cfg = OmegaConf.load(config_path)
        # resolve=True expands any ${...} interpolations so dacite receives
        # concrete values rather than unresolved placeholder strings.
        model_cfg = from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(raw_cfg, resolve=True),
            config=DaciteConfig(strict=True),
        )
        model = xLSTMLMModel(model_cfg)
        return LLMTrainer(model=model, model_returns_logits=True)

    def get_trainer(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
        """Return the trainer for ``model_name``, loading it if not cached.

        Raises:
            ValueError: If ``model_name`` has no entry in ``MODEL_CONFIGS``.
            RuntimeError: If building the model or loading its checkpoint fails.
        """
        if self._current_trainer is None or self._current_model != model_name:
            logger.info("Loading model: %s", model_name)
            self._current_trainer = self._load_model(model_name)
            self._current_model = model_name
            logger.info("Model %s loaded successfully", model_name)

        return self._current_trainer

    def _load_model(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
        """Build the named model and load its checkpoint.

        Raises:
            ValueError: If ``model_name`` has no entry in ``MODEL_CONFIGS``.
            RuntimeError: Wrapping any failure during construction or
                checkpoint loading; the original exception is chained as
                ``__cause__``.
        """
        if model_name not in MODEL_CONFIGS:
            raise ValueError(
                f"Invalid model: {model_name}. Valid models: {list(MODEL_CONFIGS.keys())}"
            )

        config = MODEL_CONFIGS[model_name]

        try:
            # Check the checkpoint before building the model, so a missing file
            # fails fast instead of after a potentially costly construction.
            checkpoint_path = Path(config.checkpoint_path)
            if not checkpoint_path.exists():
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

            if model_name == "GPT2":
                trainer = self._create_gpt2_trainer()
            elif model_name == "xLSTM":
                trainer = self._create_xlstm_trainer(config.config_path)
            else:
                raise ValueError(f"Unsupported model: {model_name}")

            logger.info("Loading checkpoint: %s", checkpoint_path)
            trainer.load_checkpoint(str(checkpoint_path))
            return trainer

        except Exception as e:
            # logger.exception records the traceback, not just the message.
            logger.exception("Failed to load model %s", model_name)
            # `from e` keeps the root cause attached to the RuntimeError.
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e


# Module-level singleton; generate_text and the startup hook share it.
model_manager = ModelManager()
|
|
|
|
def _format_continuation(prompt: str, generated: str) -> str:
    """Turn one raw generated sequence into the text shown in an output box."""
    # NOTE(review): generated sequences presumably begin with the prompt —
    # verify against LLMTrainer.generate_text. Strip it only when it is
    # really there, so a prompt altered during decoding does not cost
    # characters of real output.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    cleaned = generated.strip()
    return f"{cleaned}..." if cleaned else "(No continuation generated)"


def generate_text(
    user_input: str,
    model_choice: str = UI_CONFIG["default_model"],
    n_sequences: int = UI_CONFIG["max_sequences"],
    length: int = UI_CONFIG["default_length"],
) -> Tuple[str, str, str]:
    """Generate continuations of ``user_input`` with the selected model.

    Args:
        user_input: Prompt text typed by the user.
        model_choice: Key into ``MODEL_CONFIGS`` naming the model to use.
        n_sequences: Number of continuations to request. A float such as
            ``3.0`` (as a ``gr.Number`` input may deliver) is accepted.
        length: Generation length passed to the trainer; floats accepted.

    Returns:
        Exactly three strings, one per output textbox. Unused slots are
        empty; on failure the first slot holds the error message.
    """
    num_outputs = 3  # the interface has exactly three output textboxes

    if not user_input or not user_input.strip():
        return "Please enter some text first.", "", ""

    try:
        # Gradio components may deliver numbers as floats; slicing and the
        # trainer's sequence count need real ints.
        n_sequences = int(n_sequences)
        length = int(length)
        logger.info("Generating text with %s, length: %d", model_choice, length)

        trainer = model_manager.get_trainer(model_choice)

        continuations = trainer.generate_text(
            prompt=user_input,
            n_return_sequences=n_sequences,
            length=length,
        )

        # Clamp so a negative count cannot turn the slice into "all but last".
        shown = max(0, min(n_sequences, num_outputs))
        results = [
            _format_continuation(user_input, text) for text in continuations[:shown]
        ]
        results.extend([""] * (num_outputs - len(results)))

        logger.info("Text generation completed successfully")
        return results[0], results[1], results[2]

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        logger.exception("%s", error_msg)
        return error_msg, "", ""
|
|
|
|
def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Button]:
    """Build the prompt box, model/length controls and Generate button.

    Must be called inside an active ``gr.Blocks`` context; the components
    are laid out in a single column.

    Returns:
        ``(user_input, model_choice, length, generate_btn)`` in that order.
    """
    with gr.Column():
        prompt_box = gr.Textbox(
            label="Enter your text:",
            placeholder="Type your text here...",
            lines=3,
            max_lines=10,
        )

        # Model picker and generation-length slider share one row.
        with gr.Row():
            model_dropdown = gr.Dropdown(
                label="Model",
                choices=list(MODEL_CONFIGS.keys()),
                value=UI_CONFIG["default_model"],
                interactive=True,
            )
            length_slider = gr.Slider(
                label="Generation Length",
                minimum=UI_CONFIG["min_length"],
                maximum=UI_CONFIG["max_length"],
                step=UI_CONFIG["length_step"],
                value=UI_CONFIG["default_length"],
            )

        generate_button = gr.Button("Generate Continuation", variant="primary")

    return prompt_box, model_dropdown, length_slider, generate_button
|
|
|
|
def create_output_section() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox]:
    """Render a heading plus three read-only textboxes side by side.

    Returns:
        The three output textboxes, labelled "Continuation 1".."Continuation 3".
    """
    gr.Markdown("### Generated Continuations:")

    with gr.Row():
        # Built eagerly (list comprehension) so every box is created while
        # the Row context is active.
        boxes = [
            gr.Textbox(
                label=f"Continuation {slot}",
                lines=8,
                max_lines=15,
                interactive=False,
            )
            for slot in (1, 2, 3)
        ]

    return tuple(boxes)
|
|
|
|
def setup_event_handlers(
    user_input: gr.Textbox,
    model_choice: gr.Dropdown,
    length: gr.Slider,
    generate_btn: gr.Button,
    outputs: Tuple[gr.Textbox, gr.Textbox, gr.Textbox],
) -> None:
    """Wire the Generate button and Enter-in-textbox to ``generate_text``.

    Both triggers use the same inputs and outputs, so they behave identically.
    Must be called inside the ``gr.Blocks`` context, because it creates a
    hidden input component.
    """
    # Hidden constant input feeding generate_text's n_sequences argument.
    # precision=0 makes Gradio deliver an int; with the default (None) the
    # value arrives as a float like 3.0, which cannot be used as a slice index.
    n_sequences_input = gr.Number(
        value=UI_CONFIG["max_sequences"],
        visible=False,
        precision=0,
    )
    inputs = [user_input, model_choice, n_sequences_input, length]
    output_list = list(outputs)

    for trigger in (generate_btn.click, user_input.submit):
        trigger(fn=generate_text, inputs=inputs, outputs=output_list)
|
|
|
|
def create_interface() -> gr.Blocks:
    """Assemble the full Blocks app: header, inputs, outputs and event wiring."""
    app_title = UI_CONFIG["title"]

    with gr.Blocks(title=app_title, theme=gr.themes.Soft()) as demo:
        # Header
        gr.Markdown(f"# {app_title}")
        gr.Markdown(UI_CONFIG["description"])

        with gr.Row():
            input_components = create_input_section()

        output_components = create_output_section()

        # input_components is (user_input, model_choice, length, generate_btn),
        # matching setup_event_handlers' leading parameters.
        setup_event_handlers(*input_components, output_components)

    return demo
|
|
|
|
def initialize_model_on_startup() -> None:
    """Eagerly load the default model so the first request finds it cached.

    Failures are logged and swallowed on purpose: the app still starts, and
    the model manager attempts the load again when the model is first requested.
    """
    try:
        default_name = UI_CONFIG["default_model"]
        logger.info(f"Initializing {default_name} model on startup...")
        model_manager.get_trainer(default_name)
        logger.info(f"{default_name} model initialized successfully!")
    except Exception as exc:
        logger.warning(f"Could not initialize model on startup: {exc}")
        logger.info("Model will be initialized when first used.")
|
|
|
|
def main() -> None:
    """Entry point: warm up the default model, build the UI and serve it."""
    cfg = UI_CONFIG
    logger.info(f"Starting {cfg['title']} application...")

    initialize_model_on_startup()

    demo = create_interface()
    host, port = cfg["host"], cfg["port"]
    logger.info(f"Launching interface on {host}:{port}")

    # share=False: no public Gradio link; show_error=True: surface handler
    # errors in the browser.
    demo.launch(server_name=host, server_port=port, share=False, show_error=True)


if __name__ == "__main__":
    main()
|
|