AlekMan committed on
Commit
451e175
·
verified ·
1 Parent(s): 1148122

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -52
app.py CHANGED
@@ -1,5 +1,4 @@
1
  from typing import Literal, Optional, Tuple
2
- from dataclasses import dataclass
3
  from pathlib import Path
4
  import logging
5
 
@@ -7,6 +6,7 @@ import gradio as gr
7
  from omegaconf import OmegaConf
8
  from dacite import Config as DaciteConfig, from_dict
9
  from transformers import GPT2Config, GPT2LMHeadModel
 
10
 
11
  from llm_trainer import LLMTrainer
12
  from xlstm import xLSTMLMModel, xLSTMLMModelConfig
@@ -15,25 +15,10 @@ logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
 
18
- @dataclass
19
- class ModelConfig:
20
- name: Literal["xLSTM", "GPT2"]
21
- checkpoint_path: str
22
- config_path: Optional[str] = None
23
 
24
 
25
- MODEL_CONFIGS = {
26
- "GPT2": ModelConfig(
27
- name="GPT2",
28
- checkpoint_path="checkpoints/gpt/cp_3999.pth"
29
- ),
30
- "xLSTM": ModelConfig(
31
- name="xLSTM",
32
- checkpoint_path="checpoints/xlstm/cp_9999.pth",
33
- config_path="research/xlstm_config.yaml"
34
- )
35
- }
36
-
37
  GPT2_CONFIG = GPT2Config(
38
  vocab_size=50304,
39
  n_positions=256,
@@ -43,6 +28,9 @@ GPT2_CONFIG = GPT2Config(
43
  activation_function="gelu"
44
  )
45
 
 
 
 
46
  UI_CONFIG = {
47
  "title": "HSEAI",
48
  "description": "Enter your text below and the AI will continue it.",
@@ -57,6 +45,13 @@ UI_CONFIG = {
57
  }
58
 
59
 
 
 
 
 
 
 
 
60
  class ModelManager:
61
  """Manages model initialization and caching"""
62
 
@@ -64,26 +59,7 @@ class ModelManager:
64
  self._current_trainer: Optional[LLMTrainer] = None
65
  self._current_model: Optional[str] = None
66
 
67
- def _create_gpt2_trainer(self) -> LLMTrainer:
68
- """Create GPT2 trainer instance"""
69
- model = GPT2LMHeadModel(GPT2_CONFIG)
70
- return LLMTrainer(model=model, model_returns_logits=False)
71
-
72
- def _create_xlstm_trainer(self, config_path: str) -> LLMTrainer:
73
- """Create xLSTM trainer instance"""
74
- if not Path(config_path).exists():
75
- raise FileNotFoundError(f"xLSTM config file not found: {config_path}")
76
-
77
- cfg = OmegaConf.load(config_path)
78
- cfg = from_dict(
79
- data_class=xLSTMLMModelConfig,
80
- data=OmegaConf.to_container(cfg),
81
- config=DaciteConfig(strict=True)
82
- )
83
- model = xLSTMLMModel(cfg)
84
- return LLMTrainer(model=model, model_returns_logits=True)
85
-
86
- def get_trainer(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
87
  """Get trainer instance, creating if necessary"""
88
  if self._current_trainer is None or self._current_model != model_name:
89
  logger.info(f"Loading model: {model_name}")
@@ -95,25 +71,18 @@ class ModelManager:
95
 
96
  def _load_model(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
97
  """Load and initialize model"""
98
- if model_name not in MODEL_CONFIGS:
99
- raise ValueError(f"Invalid model: {model_name}. Valid models: {list(MODEL_CONFIGS.keys())}")
100
-
101
- config = MODEL_CONFIGS[model_name]
102
-
103
  try:
104
  if model_name == "GPT2":
105
- trainer = self._create_gpt2_trainer()
106
  elif model_name == "xLSTM":
107
- trainer = self._create_xlstm_trainer(config.config_path)
 
 
 
 
108
  else:
109
  raise ValueError(f"Unsupported model: {model_name}")
110
 
111
- checkpoint_path = Path(config.checkpoint_path)
112
- if not checkpoint_path.exists():
113
- raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
114
-
115
- logger.info(f"Loading checkpoint: {checkpoint_path}")
116
- trainer.load_checkpoint(str(checkpoint_path))
117
  return trainer
118
 
119
  except Exception as e:
@@ -178,7 +147,7 @@ def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Butto
178
 
179
  with gr.Row():
180
  model_choice = gr.Dropdown(
181
- choices=list(MODEL_CONFIGS.keys()),
182
  value=UI_CONFIG["default_model"],
183
  label="Model",
184
  interactive=True
 
1
  from typing import Literal, Optional, Tuple
 
2
  from pathlib import Path
3
  import logging
4
 
 
6
  from omegaconf import OmegaConf
7
  from dacite import Config as DaciteConfig, from_dict
8
  from transformers import GPT2Config, GPT2LMHeadModel
9
+ from huggingface_hub import PyTorchModelHubMixin
10
 
11
  from llm_trainer import LLMTrainer
12
  from xlstm import xLSTMLMModel, xLSTMLMModelConfig
 
15
  logger = logging.getLogger(__name__)
16
 
17
 
18
+ class xLSTMWrapper(xLSTMLMModel, PyTorchModelHubMixin):
19
+ pass
 
 
 
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  GPT2_CONFIG = GPT2Config(
23
  vocab_size=50304,
24
  n_positions=256,
 
28
  activation_function="gelu"
29
  )
30
 
31
+ XLSTM_CONFIG = OmegaConf.load("xlstm_config.yaml")
32
+ XLSTM_CONFIG = from_dict(data_class=xLSTMLMModelConfig, data=OmegaConf.to_container(XLSTM_CONFIG), config=DaciteConfig(strict=True))
33
+
34
  UI_CONFIG = {
35
  "title": "HSEAI",
36
  "description": "Enter your text below and the AI will continue it.",
 
45
  }
46
 
47
 
48
+ xLSTM = xLSTMWrapper(XLSTM_CONFIG).from_pretrained("AlekMan/HSE_AI_XLSTM", config=XLSTM_CONFIG)
49
+ xLSTM_ft = xLSTMWrapper(XLSTM_CONFIG).from_pretrained("AlekMan/HSE_AI_XLSTM_FT", config=XLSTM_CONFIG)
50
+ gpt2 = GPT2LMHeadModel(GPT2_CONFIG).from_pretrained("AlekMan/HSE_AI_GPT2")
51
+ gpt2_lora = GPT2LMHeadModel(GPT2_CONFIG).from_pretrained("AlekMan/HSE_AI_GPT2")
52
+ gpt2_lora.load_adapter("AlekMan/HSE_AI_GPT2_LoRA")
53
+
54
+
55
  class ModelManager:
56
  """Manages model initialization and caching"""
57
 
 
59
  self._current_trainer: Optional[LLMTrainer] = None
60
  self._current_model: Optional[str] = None
61
 
62
+ def get_trainer(self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  """Get trainer instance, creating if necessary"""
64
  if self._current_trainer is None or self._current_model != model_name:
65
  logger.info(f"Loading model: {model_name}")
 
71
 
72
  def _load_model(self, model_name: Literal["xLSTM", "GPT2"]) -> LLMTrainer:
73
  """Load and initialize model"""
 
 
 
 
 
74
  try:
75
  if model_name == "GPT2":
76
+ trainer = LLMTrainer(model=gpt2, model_returns_logits=False)
77
  elif model_name == "xLSTM":
78
+ trainer = LLMTrainer(model=xLSTM, model_returns_logits=True)
79
+ elif model_name == "GPT2_FT":
80
+ trainer = LLMTrainer(model=gpt2_lora, model_returns_logits=False)
81
+ elif model_name == "xLSTM_FT":
82
+ trainer = LLMTrainer(model=xLSTM_ft, model_returns_logits=True)
83
  else:
84
  raise ValueError(f"Unsupported model: {model_name}")
85
 
 
 
 
 
 
 
86
  return trainer
87
 
88
  except Exception as e:
 
147
 
148
  with gr.Row():
149
  model_choice = gr.Dropdown(
150
+ choices=["GPT2", "GPT2_FT", "xLSTM", "xLSTM_FT"],
151
  value=UI_CONFIG["default_model"],
152
  label="Model",
153
  interactive=True