sterepando commited on
Commit
bbdc54d
·
verified ·
1 Parent(s): 6612fc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -72
app.py CHANGED
@@ -14,7 +14,6 @@ from trl import SFTTrainer
14
  from datasets import Dataset
15
 
16
  # --- КОНФИГУРАЦИЯ ---
17
- # Используем Maincoder-1B (он легкий и не убьет Space по памяти)
18
  MODEL_ID = "Maincode/Maincoder-1B"
19
  OUTPUT_DIR = "mandre_qlora_adapter"
20
  JSON_FILE_NAME = "train_data.json"
@@ -24,19 +23,16 @@ chat_model = None
24
  chat_tokenizer = None
25
 
26
  # ==========================================
27
- # ЧАСТЬ 1: ГЕНЕРАТОР ДАТАСЕТА (ИСПРАВЛЕННЫЙ)
28
  # ==========================================
29
 
30
  def generate_json_dataset(files):
31
- # Защита от пустого ввода
32
  if not files:
33
  return None, "❌ Ошибка: Вы не загрузили файлы."
34
 
35
  data_entries = []
36
 
37
- # Gradio 4.x передает список объектов, берем пути корректно
38
  for file_item in files:
39
- # Получаем путь к файлу (обрабатываем разные варианты ввода Gradio)
40
  if isinstance(file_item, str):
41
  file_path = file_item
42
  elif hasattr(file_item, 'name'):
@@ -49,11 +45,10 @@ def generate_json_dataset(files):
49
  content = f.read()
50
  except Exception as e:
51
  print(f"Skipping file {file_path}: {e}")
52
- continue # Пропускаем бинарники или ошибки чтения
53
 
54
  filename = os.path.basename(file_path)
55
 
56
- # Формируем пару Инструкция-Ответ
57
  instruction = f"Analyze the code/text in file '{filename}' from the MandreLib project."
58
  text = f"### Instruction:\n{instruction}\n\n### Response:\n{content}<|endoftext|>"
59
 
@@ -62,37 +57,30 @@ def generate_json_dataset(files):
62
  if not data_entries:
63
  return None, "❌ Не удалось прочитать ни один текстовый файл."
64
 
65
- # Сохраняем JSON
66
  try:
67
  with open(JSON_FILE_NAME, 'w', encoding='utf-8') as f:
68
  json.dump(data_entries, f, indent=4, ensure_ascii=False)
69
-
70
- # Возвращаем ПУТЬ К ФАЙЛУ (строку) и сообщение (строку)
71
- # Важно: возвращаем абсолютный путь для надежности
72
  abs_path = os.path.abspath(JSON_FILE_NAME)
73
  return abs_path, f"✅ Готово! Обработано файлов: {len(data_entries)}. Файл {JSON_FILE_NAME} создан."
74
  except Exception as e:
75
  return None, f"❌ Ошибка записи JSON: {e}"
76
 
77
  # ==========================================
78
- # ЧАСТЬ 2: ОБУЧЕНИЕ (CPU STREAMING)
79
  # ==========================================
80
 
81
  def train_mandre_ai(file_obj, epochs, lr):
82
- # Проверка наличия файла
83
  if file_obj is None:
84
- # Пытаемся найти файл, если он был создан ранее, но не передан через UI
85
  if os.path.exists(JSON_FILE_NAME):
86
  json_path = JSON_FILE_NAME
87
- yield f"⚠️ Файл не передан в поле, но найден {JSON_FILE_NAME} на диске. Используем его."
88
  else:
89
- yield "❌ Ошибка: Загрузите JSON файл или сгенерируйте его во вкладке 1!"
90
  return
91
  else:
92
- # Gradio может передать объект файла или путь
93
  json_path = file_obj.name if hasattr(file_obj, 'name') else file_obj
94
 
95
- yield f"🚀 Начинаем процесс... (Модель: {MODEL_ID})"
96
 
97
  try:
98
  # 1. Загрузка данных
@@ -100,13 +88,13 @@ def train_mandre_ai(file_obj, epochs, lr):
100
  data = json.load(f)
101
  dataset = Dataset.from_pandas(pd.DataFrame(data))
102
 
103
- yield f"📊 Данные загружены: {len(dataset)} примеров."
104
 
105
- # 2. Токенизатор
106
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
107
  tokenizer.pad_token = tokenizer.eos_token
108
 
109
- # 3. LoRA Config (Оптимизировано для CPU)
110
  peft_config = LoraConfig(
111
  task_type=TaskType.CAUSAL_LM,
112
  inference_mode=False,
@@ -116,7 +104,7 @@ def train_mandre_ai(file_obj, epochs, lr):
116
  target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj']
117
  )
118
 
119
- # 4. Аргументы обучения
120
  training_args = TrainingArguments(
121
  output_dir=OUTPUT_DIR,
122
  num_train_epochs=float(epochs),
@@ -124,23 +112,23 @@ def train_mandre_ai(file_obj, epochs, lr):
124
  gradient_accumulation_steps=4,
125
  learning_rate=float(lr),
126
  weight_decay=0.01,
127
- use_cpu=True, # ПРИНУДИТЕЛЬНО CPU
128
- no_cuda=True, # ОТКЛЮЧИТЬ CUDA
129
- fp16=False, # CPU любит fp32
130
  logging_steps=1,
131
  save_total_limit=1,
132
  push_to_hub=False,
133
- report_to="none" # Отключаем wandb чтобы не спамил
134
  )
135
 
136
- yield "📥 Загрузка модели в память (это может занять минуту)..."
137
 
138
  # 5. Загрузка модели
139
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
140
  model.gradient_checkpointing_enable()
141
  model.enable_input_require_grads()
142
 
143
- yield "🧠 Модель готова. Инициализация тренера..."
144
 
145
  # 6. Trainer
146
  trainer = SFTTrainer(
@@ -153,20 +141,19 @@ def train_mandre_ai(file_obj, epochs, lr):
153
  max_seq_length=1024
154
  )
155
 
156
- yield "🔥 ОБУЧЕНИЕ ЗАПУЩЕНО! Процесс пошел. Это будет долго, не закрывайте вкладку."
157
 
158
- # Запуск обучения
159
  trainer.train()
160
 
161
- yield "💾 Сохранение адаптера..."
162
  trainer.model.save_pretrained(OUTPUT_DIR)
163
  tokenizer.save_pretrained(OUTPUT_DIR)
164
 
165
- yield f"✅ УРА! Обучение завершено. Адаптер сохранен в папку '{OUTPUT_DIR}'. Переходите в Чат."
166
 
167
  except Exception as e:
168
  import traceback
169
- yield f"❌ КРИТИЧЕСКАЯ ОШИБКА:\n{traceback.format_exc()}"
170
 
171
  # ==========================================
172
  # ЧАСТЬ 3: ЧАТ
@@ -177,18 +164,19 @@ def load_chat_model():
177
  if chat_model is not None: return "Уже загружено"
178
 
179
  try:
180
- chat_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
181
 
182
  if os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
183
- base = AutoModelForCausalLM.from_pretrained(MODEL_ID)
184
  chat_model = PeftModel.from_pretrained(base, OUTPUT_DIR)
185
- return f"✅ Загружен адаптер QLoRA из {OUTPUT_DIR}!"
186
  else:
187
- chat_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
188
- return "⚠️ Адаптер не найден (не обучили?). Загружена 'чистая' модель."
189
 
190
  except Exception as e:
191
- return f"Ошибка загрузки: {e}"
192
 
193
  def generate_answer(prompt, history):
194
  if not chat_model:
@@ -198,7 +186,6 @@ def generate_answer(prompt, history):
198
  formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
199
  inputs = chat_tokenizer(formatted_prompt, return_tensors="pt")
200
 
201
- # Генерация
202
  outputs = chat_model.generate(
203
  **inputs,
204
  max_new_tokens=300,
@@ -216,52 +203,33 @@ def generate_answer(prompt, history):
216
  # ИНТЕРФЕЙС
217
  # ==========================================
218
 
219
- with gr.Blocks(title="MandreAI 1B CPU Trainer") as demo:
220
- gr.Markdown("# 🦎 MandreAI (Stable CPU Version)")
221
- gr.Markdown("Версия для обучения на бесплатном процессоре Hugging Face.")
222
 
223
  with gr.Tabs():
224
- # Вкладка 1
225
  with gr.Tab("1. Датасет"):
226
- gr.Markdown("Загрузите файлы `.py`, `.md`, `.java`.")
227
  files_input = gr.File(file_count="multiple", label="Исходные файлы")
228
  btn_gen = gr.Button("Создать JSON", variant="primary")
229
-
230
- # ВАЖНО: Определяем выходы точно так, как возвращает функция (File, Textbox)
231
- json_output = gr.File(label="Готовый датасет (скачается автоматически)")
232
- status_output = gr.Textbox(label="Статус генерации")
233
-
234
- btn_gen.click(
235
- generate_json_dataset,
236
- inputs=[files_input],
237
- outputs=[json_output, status_output]
238
- )
239
 
240
- # Вкладка 2
241
  with gr.Tab("2. Обучение"):
242
- gr.Markdown("Загрузите полученный `.json` файл сюда (или он подхватится автоматически, если создан).")
243
  with gr.Row():
244
- train_file_input = gr.File(label="Файл train_data.json")
245
  epochs = gr.Number(value=3, label="Эпохи", precision=0)
246
- lr = gr.Number(value=2e-4, label="Learning Rate")
247
 
248
  btn_train = gr.Button("ЗАПУСТИТЬ ОБУЧЕНИЕ", variant="stop")
249
- # Используем Textbox как лог
250
- log_output = gr.Textbox(label="Лог процесса (обновляется в реальном времени)", lines=10, interactive=False)
251
 
252
- btn_train.click(
253
- train_mandre_ai,
254
- inputs=[train_file_input, epochs, lr],
255
- outputs=[log_output]
256
- )
257
 
258
- # Вкладка 3
259
  with gr.Tab("3. Чат"):
260
  chatbot = gr.Chatbot(label="MandreAI")
261
- msg_input = gr.Textbox(label="Ваш вопрос")
262
  btn_send = gr.Button("Отправить")
263
-
264
  btn_send.click(generate_answer, [msg_input, chatbot], chatbot)
265
 
266
  if __name__ == "__main__":
267
- demo.queue().launch(allowed_paths=["."]) # Разрешаем доступ к локальным файлам
 
14
  from datasets import Dataset
15
 
16
  # --- КОНФИГУРАЦИЯ ---
 
17
  MODEL_ID = "Maincode/Maincoder-1B"
18
  OUTPUT_DIR = "mandre_qlora_adapter"
19
  JSON_FILE_NAME = "train_data.json"
 
23
  chat_tokenizer = None
24
 
25
  # ==========================================
26
+ # ЧАСТЬ 1: ГЕНЕРАТОР ДАТАСЕТА
27
  # ==========================================
28
 
29
  def generate_json_dataset(files):
 
30
  if not files:
31
  return None, "❌ Ошибка: Вы не загрузили файлы."
32
 
33
  data_entries = []
34
 
 
35
  for file_item in files:
 
36
  if isinstance(file_item, str):
37
  file_path = file_item
38
  elif hasattr(file_item, 'name'):
 
45
  content = f.read()
46
  except Exception as e:
47
  print(f"Skipping file {file_path}: {e}")
48
+ continue
49
 
50
  filename = os.path.basename(file_path)
51
 
 
52
  instruction = f"Analyze the code/text in file '{filename}' from the MandreLib project."
53
  text = f"### Instruction:\n{instruction}\n\n### Response:\n{content}<|endoftext|>"
54
 
 
57
  if not data_entries:
58
  return None, "❌ Не удалось прочитать ни один текстовый файл."
59
 
 
60
  try:
61
  with open(JSON_FILE_NAME, 'w', encoding='utf-8') as f:
62
  json.dump(data_entries, f, indent=4, ensure_ascii=False)
 
 
 
63
  abs_path = os.path.abspath(JSON_FILE_NAME)
64
  return abs_path, f"✅ Готово! Обработано файлов: {len(data_entries)}. Файл {JSON_FILE_NAME} создан."
65
  except Exception as e:
66
  return None, f"❌ Ошибка записи JSON: {e}"
67
 
68
  # ==========================================
69
+ # ЧАСТЬ 2: ОБУЧЕНИЕ (ИСПРАВЛЕНО)
70
  # ==========================================
71
 
72
  def train_mandre_ai(file_obj, epochs, lr):
 
73
  if file_obj is None:
 
74
  if os.path.exists(JSON_FILE_NAME):
75
  json_path = JSON_FILE_NAME
76
+ yield f"⚠️ Файл не передан, используем {JSON_FILE_NAME} из прошлой генерации."
77
  else:
78
+ yield "❌ Ошибка: Нет файла с данными!"
79
  return
80
  else:
 
81
  json_path = file_obj.name if hasattr(file_obj, 'name') else file_obj
82
 
83
+ yield f"🚀 Старт обучения {MODEL_ID}..."
84
 
85
  try:
86
  # 1. Загрузка данных
 
88
  data = json.load(f)
89
  dataset = Dataset.from_pandas(pd.DataFrame(data))
90
 
91
+ yield f"📊 Данные: {len(dataset)} строк. Загрузка токенизатора..."
92
 
93
+ # 2. Токенизатор (FIX: use_fast=False чтобы избежать ошибки Rust)
94
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
95
  tokenizer.pad_token = tokenizer.eos_token
96
 
97
+ # 3. LoRA Config
98
  peft_config = LoraConfig(
99
  task_type=TaskType.CAUSAL_LM,
100
  inference_mode=False,
 
104
  target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj']
105
  )
106
 
107
+ # 4. Аргументы
108
  training_args = TrainingArguments(
109
  output_dir=OUTPUT_DIR,
110
  num_train_epochs=float(epochs),
 
112
  gradient_accumulation_steps=4,
113
  learning_rate=float(lr),
114
  weight_decay=0.01,
115
+ use_cpu=True,
116
+ no_cuda=True,
117
+ fp16=False,
118
  logging_steps=1,
119
  save_total_limit=1,
120
  push_to_hub=False,
121
+ report_to="none"
122
  )
123
 
124
+ yield "📥 Загрузка модели (Maincoder-1B)..."
125
 
126
  # 5. Загрузка модели
127
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
128
  model.gradient_checkpointing_enable()
129
  model.enable_input_require_grads()
130
 
131
+ yield "🧠 Инициализация тренера..."
132
 
133
  # 6. Trainer
134
  trainer = SFTTrainer(
 
141
  max_seq_length=1024
142
  )
143
 
144
+ yield "🔥 ОБУЧЕНИЕ ЗАПУЩЕНО! Ждите завершения..."
145
 
 
146
  trainer.train()
147
 
148
+ yield "💾 Сохранение..."
149
  trainer.model.save_pretrained(OUTPUT_DIR)
150
  tokenizer.save_pretrained(OUTPUT_DIR)
151
 
152
+ yield f"✅ УСПЕХ! Адаптер в папке '{OUTPUT_DIR}'. Можно чатиться."
153
 
154
  except Exception as e:
155
  import traceback
156
+ yield f"❌ ОШИБКА:\n{traceback.format_exc()}"
157
 
158
  # ==========================================
159
  # ЧАСТЬ 3: ЧАТ
 
164
  if chat_model is not None: return "Уже загружено"
165
 
166
  try:
167
+ # FIX: use_fast=False и здесь
168
+ chat_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
169
 
170
  if os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
171
+ base = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
172
  chat_model = PeftModel.from_pretrained(base, OUTPUT_DIR)
173
+ return f"✅ Адаптер QLoRA загружен!"
174
  else:
175
+ chat_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
176
+ return "⚠️ Адаптер не найден. Работает чистая модель."
177
 
178
  except Exception as e:
179
+ return f"Ошибка: {e}"
180
 
181
  def generate_answer(prompt, history):
182
  if not chat_model:
 
186
  formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
187
  inputs = chat_tokenizer(formatted_prompt, return_tensors="pt")
188
 
 
189
  outputs = chat_model.generate(
190
  **inputs,
191
  max_new_tokens=300,
 
203
  # ИНТЕРФЕЙС
204
  # ==========================================
205
 
206
+ with gr.Blocks(title="MandreAI Fix") as demo:
207
+ gr.Markdown("# 🦎 MandreAI 1B (CPU Fix)")
 
208
 
209
  with gr.Tabs():
 
210
  with gr.Tab("1. Датасет"):
 
211
  files_input = gr.File(file_count="multiple", label="Исходные файлы")
212
  btn_gen = gr.Button("Создать JSON", variant="primary")
213
+ json_output = gr.File(label="Готовый датасет")
214
+ status_output = gr.Textbox(label="Статус")
215
+ btn_gen.click(generate_json_dataset, inputs=[files_input], outputs=[json_output, status_output])
 
 
 
 
 
 
 
216
 
 
217
  with gr.Tab("2. Обучение"):
 
218
  with gr.Row():
219
+ train_file_input = gr.File(label="train_data.json")
220
  epochs = gr.Number(value=3, label="Эпохи", precision=0)
221
+ lr = gr.Number(value=2e-4, label="LR")
222
 
223
  btn_train = gr.Button("ЗАПУСТИТЬ ОБУЧЕНИЕ", variant="stop")
224
+ log_output = gr.Textbox(label="Лог", lines=10)
 
225
 
226
+ btn_train.click(train_mandre_ai, inputs=[train_file_input, epochs, lr], outputs=[log_output])
 
 
 
 
227
 
 
228
  with gr.Tab("3. Чат"):
229
  chatbot = gr.Chatbot(label="MandreAI")
230
+ msg_input = gr.Textbox(label="Вопрос")
231
  btn_send = gr.Button("Отправить")
 
232
  btn_send.click(generate_answer, [msg_input, chatbot], chatbot)
233
 
234
  if __name__ == "__main__":
235
+ demo.queue().launch(allowed_paths=["."])