Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
from typing import Literal, Optional, Tuple
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
from pathlib import Path
|
| 4 |
import logging
|
| 5 |
|
|
@@ -7,6 +6,7 @@ import gradio as gr
|
|
| 7 |
from omegaconf import OmegaConf
|
| 8 |
from dacite import Config as DaciteConfig, from_dict
|
| 9 |
from transformers import GPT2Config, GPT2LMHeadModel
|
|
|
|
| 10 |
|
| 11 |
from llm_trainer import LLMTrainer
|
| 12 |
from xlstm import xLSTMLMModel, xLSTMLMModelConfig
|
|
@@ -15,25 +15,10 @@ logging.basicConfig(level=logging.INFO)
|
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
name: Literal["xLSTM", "GPT2"]
|
| 21 |
-
checkpoint_path: str
|
| 22 |
-
config_path: Optional[str] = None
|
| 23 |
|
| 24 |
|
| 25 |
-
MODEL_CONFIGS = {
|
| 26 |
-
"GPT2": ModelConfig(
|
| 27 |
-
name="GPT2",
|
| 28 |
-
checkpoint_path="checkpoints/gpt/cp_3999.pth"
|
| 29 |
-
),
|
| 30 |
-
"xLSTM": ModelConfig(
|
| 31 |
-
name="xLSTM",
|
| 32 |
-
checkpoint_path="checpoints/xlstm/cp_9999.pth",
|
| 33 |
-
config_path="research/xlstm_config.yaml"
|
| 34 |
-
)
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
GPT2_CONFIG = GPT2Config(
|
| 38 |
vocab_size=50304,
|
| 39 |
n_positions=256,
|
|
@@ -43,6 +28,9 @@ GPT2_CONFIG = GPT2Config(
|
|
| 43 |
activation_function="gelu"
|
| 44 |
)
|
| 45 |
|
|
|
|
|
|
|
|
|
|
| 46 |
UI_CONFIG = {
|
| 47 |
"title": "HSEAI",
|
| 48 |
"description": "Enter your text below and the AI will continue it.",
|
|
@@ -57,6 +45,13 @@ UI_CONFIG = {
|
|
| 57 |
}
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
class ModelManager:
|
| 61 |
"""Manages model initialization and caching"""
|
| 62 |
|
|
@@ -64,26 +59,7 @@ class ModelManager:
|
|
| 64 |
self._current_trainer: Optional[LLMTrainer] = None
|
| 65 |
self._current_model: Optional[str] = None
|
| 66 |
|
| 67 |
-
def
|
| 68 |
-
"""Create GPT2 trainer instance"""
|
| 69 |
-
model = GPT2LMHeadModel(GPT2_CONFIG)
|
| 70 |
-
return LLMTrainer(model=model, model_returns_logits=False)
|
| 71 |
-
|
| 72 |
-
def _create_xlstm_trainer(self, config_path: str) -> LLMTrainer:
|
| 73 |
-
"""Create xLSTM trainer instance"""
|
| 74 |
-
if not Path(config_path).exists():
|
| 75 |
-
raise FileNotFoundError(f"xLSTM config file not found: {config_path}")
|
| 76 |
-
|
| 77 |
-
cfg = OmegaConf.load(config_path)
|
| 78 |
-
cfg = from_dict(
|
| 79 |
-
data_class=xLSTMLMModelConfig,
|
| 80 |
-
data=OmegaConf.to_container(cfg),
|
| 81 |
-
config=DaciteConfig(strict=True)
|
| 82 |
-
)
|
| 83 |
-
model = xLSTMLMModel(cfg)
|
| 84 |
-
return LLMTrainer(model=model, model_returns_logits=True)
|
| 85 |
-
|
| 86 |
-
def get_trainer(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
|
| 87 |
"""Get trainer instance, creating if necessary"""
|
| 88 |
if self._current_trainer is None or self._current_model != model_name:
|
| 89 |
logger.info(f"Loading model: {model_name}")
|
|
@@ -95,25 +71,18 @@ class ModelManager:
|
|
| 95 |
|
| 96 |
def _load_model(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
|
| 97 |
"""Load and initialize model"""
|
| 98 |
-
if model_name not in MODEL_CONFIGS:
|
| 99 |
-
raise ValueError(f"Invalid model: {model_name}. Valid models: {list(MODEL_CONFIGS.keys())}")
|
| 100 |
-
|
| 101 |
-
config = MODEL_CONFIGS[model_name]
|
| 102 |
-
|
| 103 |
try:
|
| 104 |
if model_name == "GPT2":
|
| 105 |
-
trainer =
|
| 106 |
elif model_name == "xLSTM":
|
| 107 |
-
trainer =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
else:
|
| 109 |
raise ValueError(f"Unsupported model: {model_name}")
|
| 110 |
|
| 111 |
-
checkpoint_path = Path(config.checkpoint_path)
|
| 112 |
-
if not checkpoint_path.exists():
|
| 113 |
-
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
| 114 |
-
|
| 115 |
-
logger.info(f"Loading checkpoint: {checkpoint_path}")
|
| 116 |
-
trainer.load_checkpoint(str(checkpoint_path))
|
| 117 |
return trainer
|
| 118 |
|
| 119 |
except Exception as e:
|
|
@@ -178,7 +147,7 @@ def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Butto
|
|
| 178 |
|
| 179 |
with gr.Row():
|
| 180 |
model_choice = gr.Dropdown(
|
| 181 |
-
choices=
|
| 182 |
value=UI_CONFIG["default_model"],
|
| 183 |
label="Model",
|
| 184 |
interactive=True
|
|
|
|
| 1 |
from typing import Literal, Optional, Tuple
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
import logging
|
| 4 |
|
|
|
|
| 6 |
from omegaconf import OmegaConf
|
| 7 |
from dacite import Config as DaciteConfig, from_dict
|
| 8 |
from transformers import GPT2Config, GPT2LMHeadModel
|
| 9 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 10 |
|
| 11 |
from llm_trainer import LLMTrainer
|
| 12 |
from xlstm import xLSTMLMModel, xLSTMLMModelConfig
|
|
|
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
|
| 18 |
+
class xLSTMWrapper(xLSTMLMModel, PyTorchModelHubMixin):
|
| 19 |
+
pass
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
GPT2_CONFIG = GPT2Config(
|
| 23 |
vocab_size=50304,
|
| 24 |
n_positions=256,
|
|
|
|
| 28 |
activation_function="gelu"
|
| 29 |
)
|
| 30 |
|
| 31 |
+
XLSTM_CONFIG = OmegaConf.load("xlstm_config.yaml")
|
| 32 |
+
XLSTM_CONFIG = from_dict(data_class=xLSTMLMModelConfig, data=OmegaConf.to_container(XLSTM_CONFIG), config=DaciteConfig(strict=True))
|
| 33 |
+
|
| 34 |
UI_CONFIG = {
|
| 35 |
"title": "HSEAI",
|
| 36 |
"description": "Enter your text below and the AI will continue it.",
|
|
|
|
| 45 |
}
|
| 46 |
|
| 47 |
|
| 48 |
+
xLSTM = xLSTMWrapper(XLSTM_CONFIG).from_pretrained("AlekMan/HSE_AI_XLSTM", config=XLSTM_CONFIG)
|
| 49 |
+
xLSTM_ft = xLSTMWrapper(XLSTM_CONFIG).from_pretrained("AlekMan/HSE_AI_XLSTM_FT", config=XLSTM_CONFIG)
|
| 50 |
+
gpt2 = GPT2LMHeadModel(GPT2_CONFIG).from_pretrained("AlekMan/HSE_AI_GPT2")
|
| 51 |
+
gpt2_lora = GPT2LMHeadModel(GPT2_CONFIG).from_pretrained("AlekMan/HSE_AI_GPT2")
|
| 52 |
+
gpt2_lora.load_adapter("AlekMan/HSE_AI_GPT2_LoRA")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
class ModelManager:
|
| 56 |
"""Manages model initialization and caching"""
|
| 57 |
|
|
|
|
| 59 |
self._current_trainer: Optional[LLMTrainer] = None
|
| 60 |
self._current_model: Optional[str] = None
|
| 61 |
|
| 62 |
+
def get_trainer(self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"""Get trainer instance, creating if necessary"""
|
| 64 |
if self._current_trainer is None or self._current_model != model_name:
|
| 65 |
logger.info(f"Loading model: {model_name}")
|
|
|
|
| 71 |
|
| 72 |
def _load_model(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
|
| 73 |
"""Load and initialize model"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
try:
|
| 75 |
if model_name == "GPT2":
|
| 76 |
+
trainer = LLMTrainer(model=gpt2, model_returns_logits=False)
|
| 77 |
elif model_name == "xLSTM":
|
| 78 |
+
trainer = LLMTrainer(model=xLSTM, model_returns_logits=True)
|
| 79 |
+
elif model_name == "GPT2_FT":
|
| 80 |
+
trainer = LLMTrainer(model=gpt2_lora, model_returns_logits=False)
|
| 81 |
+
elif model_name == "xLSTM_FT":
|
| 82 |
+
trainer = LLMTrainer(model=xLSTM_ft, model_returns_logits=True)
|
| 83 |
else:
|
| 84 |
raise ValueError(f"Unsupported model: {model_name}")
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
return trainer
|
| 87 |
|
| 88 |
except Exception as e:
|
|
|
|
| 147 |
|
| 148 |
with gr.Row():
|
| 149 |
model_choice = gr.Dropdown(
|
| 150 |
+
choices=["GPT2", "GPT2_FT", "xLSTM", "xLSTM_FT"],
|
| 151 |
value=UI_CONFIG["default_model"],
|
| 152 |
label="Model",
|
| 153 |
interactive=True
|