# HSE_AI / app.py — Gradio text-continuation demo (Hugging Face Space upload).
from typing import Literal, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path
import logging
import gradio as gr
from omegaconf import OmegaConf
from dacite import Config as DaciteConfig, from_dict
from transformers import GPT2Config, GPT2LMHeadModel
from llm_trainer import LLMTrainer
from xlstm import xLSTMLMModel, xLSTMLMModelConfig
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class ModelConfig:
    """Static description of a selectable model: its registry name, the
    checkpoint to restore, and an optional YAML architecture config."""
    # Which backbone this entry refers to; must match the MODEL_CONFIGS key.
    name: Literal["xLSTM", "GPT2"]
    # Filesystem path to the trained weights (.pth) restored at load time.
    checkpoint_path: str
    # Path to a YAML model config; only used by xLSTM (GPT2 config is inline).
    config_path: Optional[str] = None
# Registry of selectable models; keys are surfaced directly in the UI dropdown
# and must match ModelConfig.name.
MODEL_CONFIGS = {
    "GPT2": ModelConfig(
        name="GPT2",
        checkpoint_path="checkpoints/gpt/cp_3999.pth",
    ),
    "xLSTM": ModelConfig(
        name="xLSTM",
        # Fixed typo: was "checpoints/..." (missing 'k'), which could never
        # match the "checkpoints/" directory the GPT2 entry uses, so xLSTM
        # loading always failed with FileNotFoundError — TODO confirm the
        # on-disk directory name is indeed "checkpoints".
        checkpoint_path="checkpoints/xlstm/cp_9999.pth",
        config_path="research/xlstm_config.yaml",
    ),
}
# Architecture hyper-parameters for the GPT2 variant; these must match the
# checkpoint at MODEL_CONFIGS["GPT2"].checkpoint_path or weight loading fails.
GPT2_CONFIG = GPT2Config(
    vocab_size=50304,  # presumably padded vocab (multiple of 64) — TODO confirm tokenizer matches
    n_positions=256,  # maximum sequence length
    n_embd=768,  # hidden/embedding size
    n_layer=12,  # number of transformer blocks
    n_head=12,  # attention heads per block
    activation_function="gelu"
)
# All UI-facing knobs and server settings in one place.
UI_CONFIG = {
    "title": "HSEAI",  # page title and top-of-page heading
    "description": "Enter your text below and the AI will continue it.",
    "port": 7860,  # default Gradio/Spaces port
    "host": "0.0.0.0",  # bind all interfaces (required inside containers)
    "default_model": "xLSTM",  # model preloaded at startup
    "max_sequences": 3,  # continuations per request; matches the 3 output boxes
    "default_length": 64,  # slider default (tokens) — units assumed, TODO confirm
    "min_length": 16,
    "max_length": 128,
    "length_step": 16
}
class ModelManager:
    """Lazily creates and caches a single LLMTrainer.

    Only one model is held in memory at a time: requesting a different model
    drops the cached trainer and loads the newly requested one.
    """

    def __init__(self) -> None:
        # Trainer for the most recently requested model (None until first use).
        self._current_trainer: Optional[LLMTrainer] = None
        # Name of the model the cached trainer belongs to.
        self._current_model: Optional[str] = None

    def _create_gpt2_trainer(self) -> LLMTrainer:
        """Build a GPT2 trainer from the inline GPT2_CONFIG."""
        model = GPT2LMHeadModel(GPT2_CONFIG)
        # HF GPT2LMHeadModel returns a ModelOutput, not raw logits.
        return LLMTrainer(model=model, model_returns_logits=False)

    def _create_xlstm_trainer(self, config_path: Optional[str]) -> LLMTrainer:
        """Build an xLSTM trainer from a YAML architecture config.

        Raises:
            ValueError: if no config path was configured.
            FileNotFoundError: if the config file does not exist.
        """
        if config_path is None:
            # Previously Path(None) raised an opaque TypeError; fail clearly.
            raise ValueError("xLSTM requires a config_path, but none was configured")
        if not Path(config_path).exists():
            raise FileNotFoundError(f"xLSTM config file not found: {config_path}")
        cfg = OmegaConf.load(config_path)
        # strict=True makes unknown YAML keys fail fast instead of being ignored.
        cfg = from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(cfg),
            config=DaciteConfig(strict=True),
        )
        model = xLSTMLMModel(cfg)
        return LLMTrainer(model=model, model_returns_logits=True)

    def get_trainer(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
        """Return a trainer for *model_name*, loading on first use or on switch."""
        if self._current_trainer is None or self._current_model != model_name:
            logger.info("Loading model: %s", model_name)
            self._current_trainer = self._load_model(model_name)
            self._current_model = model_name
            logger.info("Model %s loaded successfully", model_name)
        return self._current_trainer

    def _load_model(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
        """Instantiate the requested model and restore its checkpoint.

        Raises:
            ValueError: for model names not present in MODEL_CONFIGS.
            RuntimeError: wrapping any failure during construction/loading.
        """
        if model_name not in MODEL_CONFIGS:
            raise ValueError(f"Invalid model: {model_name}. Valid models: {list(MODEL_CONFIGS.keys())}")
        config = MODEL_CONFIGS[model_name]
        try:
            if model_name == "GPT2":
                trainer = self._create_gpt2_trainer()
            elif model_name == "xLSTM":
                trainer = self._create_xlstm_trainer(config.config_path)
            else:
                raise ValueError(f"Unsupported model: {model_name}")
            checkpoint_path = Path(config.checkpoint_path)
            if not checkpoint_path.exists():
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
            logger.info("Loading checkpoint: %s", checkpoint_path)
            trainer.load_checkpoint(str(checkpoint_path))
            return trainer
        except Exception as e:
            logger.error("Failed to load model %s: %s", model_name, e)
            # Chain the original exception so its traceback is not lost.
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e
# Module-level singleton so the loaded weights are reused across requests.
model_manager = ModelManager()
def generate_text(
    user_input: str,
    model_choice: str = UI_CONFIG["default_model"],
    n_sequences: int = UI_CONFIG["max_sequences"],
    length: int = UI_CONFIG["default_length"]
) -> Tuple[str, str, str]:
    """Generate up to three text continuations with the selected model.

    Returns a 3-tuple of strings, one per output textbox; unused slots are
    empty strings, and on failure the first slot carries the error message.
    """
    if not user_input.strip():
        return "Please enter some text first.", "", ""
    try:
        logger.info(f"Generating text with {model_choice}, length: {length}")
        trainer = model_manager.get_trainer(model_choice)
        continuations = trainer.generate_text(
            prompt=user_input,
            n_return_sequences=n_sequences,
            length=length
        )
        results = []
        for continuation in continuations[:n_sequences]:
            # Assumes the generator echoes the prompt verbatim at the start of
            # each sequence — TODO confirm against LLMTrainer.generate_text.
            clean_continuation = continuation[len(user_input):].strip()
            if clean_continuation:
                results.append(clean_continuation + "...")
            else:
                results.append("(No continuation generated)")
        # Pad to exactly three slots to match the three output textboxes.
        while len(results) < 3:
            results.append("")
        logger.info("Text generation completed successfully")
        return results[0], results[1], results[2]
    except Exception as e:  # UI boundary: surface any failure in the interface
        error_msg = f"Error during generation: {str(e)}"
        # logger.exception keeps the traceback; logger.error dropped it.
        logger.exception(error_msg)
        return error_msg, "", ""
def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Button]:
    """Build the prompt textbox, model/length controls, and the generate button."""
    with gr.Column():
        prompt_box = gr.Textbox(
            label="Enter your text:",
            placeholder="Type your text here...",
            lines=3,
            max_lines=10,
        )
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_CONFIGS.keys()),
                value=UI_CONFIG["default_model"],
                label="Model",
                interactive=True,
            )
            length_slider = gr.Slider(
                minimum=UI_CONFIG["min_length"],
                maximum=UI_CONFIG["max_length"],
                value=UI_CONFIG["default_length"],
                step=UI_CONFIG["length_step"],
                label="Generation Length",
            )
        submit_btn = gr.Button("Generate Continuation", variant="primary")
    return prompt_box, model_dropdown, length_slider, submit_btn
def create_output_section() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox]:
    """Build the three read-only textboxes that show generated continuations."""
    gr.Markdown("### Generated Continuations:")
    with gr.Row():
        # All three boxes are identical except for their numbered label.
        boxes = [
            gr.Textbox(
                label=f"Continuation {slot}",
                lines=8,
                max_lines=15,
                interactive=False,
            )
            for slot in (1, 2, 3)
        ]
    return boxes[0], boxes[1], boxes[2]
def setup_event_handlers(
    user_input: gr.Textbox,
    model_choice: gr.Dropdown,
    length: gr.Slider,
    generate_btn: gr.Button,
    outputs: Tuple[gr.Textbox, gr.Textbox, gr.Textbox]
) -> None:
    """Wire both the button click and the textbox submit to generate_text."""
    # Hidden constant component supplies the fixed n_sequences argument.
    hidden_count = gr.Number(value=UI_CONFIG["max_sequences"], visible=False)
    shared_inputs = [user_input, model_choice, hidden_count, length]
    for trigger in (generate_btn.click, user_input.submit):
        trigger(
            fn=generate_text,
            inputs=shared_inputs,
            outputs=list(outputs),
        )
def create_interface() -> gr.Blocks:
    """Assemble the complete Gradio Blocks UI and return it (not launched)."""
    with gr.Blocks(title=UI_CONFIG["title"], theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {UI_CONFIG['title']}")
        gr.Markdown(UI_CONFIG["description"])
        with gr.Row():
            prompt_box, model_dropdown, length_slider, submit_btn = create_input_section()
            output_boxes = create_output_section()
        # Handlers must be registered inside the Blocks context (they create
        # a hidden component).
        setup_event_handlers(prompt_box, model_dropdown, length_slider, submit_btn, output_boxes)
    return demo
def initialize_model_on_startup() -> None:
    """Warm-load the default model; fall back to lazy loading on any failure."""
    default_name = UI_CONFIG["default_model"]
    try:
        logger.info(f"Initializing {default_name} model on startup...")
        model_manager.get_trainer(default_name)
        logger.info(f"{default_name} model initialized successfully!")
    except Exception as e:
        # Best-effort warm-up: a failure here must not prevent the UI from
        # starting, since the model loads again on first request.
        logger.warning(f"Could not initialize model on startup: {e}")
        logger.info("Model will be initialized when first used.")
def main() -> None:
    """Entry point: warm up the default model, build the UI, and serve it."""
    logger.info(f"Starting {UI_CONFIG['title']} application...")
    initialize_model_on_startup()
    app = create_interface()
    host = UI_CONFIG["host"]
    port = UI_CONFIG["port"]
    logger.info(f"Launching interface on {host}:{port}")
    app.launch(
        server_name=host,
        server_port=port,
        share=False,
        show_error=True,
    )
# Allow `python app.py` to start the server without side effects on import.
if __name__ == "__main__":
    main()