# NOTE: the original paste began with "Spaces: / Paused / Paused" — Hugging Face
# Spaces page-header residue, not code. Replaced with this comment so the file parses.
import gradio as gr
import json
import os
import torch
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from trl import SFTTrainer
from datasets import Dataset

# --- CONFIGURATION ---
MODEL_ID = "Maincode/Maincoder-1B"      # base model id on the Hugging Face Hub
OUTPUT_DIR = "mandre_qlora_adapter"     # directory where the trained LoRA adapter is saved
JSON_FILE_NAME = "train_data.json"      # default dataset file produced by the generator tab

# Global chat state, lazily initialized by load_chat_model()
chat_model = None
chat_tokenizer = None
# ==========================================
# PART 1: DATASET GENERATOR
# ==========================================
def generate_json_dataset(files):
    """Build an instruction-tuning JSON dataset from uploaded files.

    Each readable file becomes one ``{"text": ...}`` entry in Alpaca-style
    "### Instruction / ### Response" format and the result is written to
    JSON_FILE_NAME in the working directory.

    Args:
        files: list of gradio file objects (or plain path strings); may be
            None/empty if the user uploaded nothing.

    Returns:
        Tuple ``(abs_path_or_None, status_message)`` matching the two
        gradio output components.
    """
    if not files:
        return None, "❌ Ошибка: Вы не загрузили файлы."
    data_entries = []
    for file_item in files:
        # gradio may hand us either a bare path string or a tempfile wrapper
        if isinstance(file_item, str):
            file_path = file_item
        elif hasattr(file_item, 'name'):
            file_path = file_item.name
        else:
            continue
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            # unreadable / binary files are skipped, not fatal
            print(f"Skipping file {file_path}: {e}")
            continue
        filename = os.path.basename(file_path)
        # FIX: interpolate the actual filename — the original had the literal
        # string '(unknown)' here and never used the `filename` variable.
        instruction = f"Analyze the code/text in file '{filename}' from the MandreLib project."
        text = f"### Instruction:\n{instruction}\n\n### Response:\n{content}<|endoftext|>"
        data_entries.append({"text": text})
    if not data_entries:
        return None, "❌ Не удалось прочитать ни один текстовый файл."
    try:
        with open(JSON_FILE_NAME, 'w', encoding='utf-8') as f:
            json.dump(data_entries, f, indent=4, ensure_ascii=False)
        abs_path = os.path.abspath(JSON_FILE_NAME)
        return abs_path, f"✅ Готово! Обработано файлов: {len(data_entries)}. Файл {JSON_FILE_NAME} создан."
    except Exception as e:
        return None, f"❌ Ошибка записи JSON: {e}"
# ==========================================
# PART 2: TRAINING (FIXED)
# ==========================================
def train_mandre_ai(file_obj, epochs, lr):
    """Fine-tune MODEL_ID with LoRA on a JSON dataset, streaming progress to the UI.

    Generator function: each ``yield`` is a human-readable (Russian) status
    line shown in the gradio log textbox. On success the trained adapter and
    tokenizer are saved into OUTPUT_DIR.

    Args:
        file_obj: uploaded dataset file (gradio File or path string), or None
            to reuse JSON_FILE_NAME from a previous generation step.
        epochs: number of training epochs (coerced to float).
        lr: learning rate (coerced to float).
    """
    if file_obj is None:
        # Fall back to the dataset produced earlier by the generator tab.
        if os.path.exists(JSON_FILE_NAME):
            json_path = JSON_FILE_NAME
            yield f"⚠️ Файл не передан, используем {JSON_FILE_NAME} из прошлой генерации."
        else:
            yield "❌ Ошибка: Нет файла с данными!"
            return
    else:
        json_path = file_obj.name if hasattr(file_obj, 'name') else file_obj
    yield f"🚀 Старт обучения {MODEL_ID}..."
    try:
        # 1. Load the training data (list of {"text": ...} dicts)
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        dataset = Dataset.from_pandas(pd.DataFrame(data))
        yield f"📊 Данные: {len(dataset)} строк. Загрузка токенизатора..."
        # 2. Tokenizer (FIX: use_fast=False to avoid the Rust tokenizer error)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
        # Model ships without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        # 3. LoRA configuration (rank-16 adapters on the attention projections)
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj']
        )
        # 4. Training arguments — CPU-only run, fp32.
        # NOTE(review): `no_cuda` is a deprecated alias of `use_cpu` in recent
        # transformers releases; passing both may warn or conflict depending on
        # the installed version — verify against the pinned transformers.
        training_args = TrainingArguments(
            output_dir=OUTPUT_DIR,
            num_train_epochs=float(epochs),
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,  # effective batch size 8
            learning_rate=float(lr),
            weight_decay=0.01,
            use_cpu=True,
            no_cuda=True,
            fp16=False,
            logging_steps=1,
            save_total_limit=1,
            push_to_hub=False,
            report_to="none"
        )
        yield "📥 Загрузка модели (Maincoder-1B)..."
        # 5. Load the base model
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
        model.gradient_checkpointing_enable()   # trade compute for memory
        model.enable_input_require_grads()      # needed for checkpointing + PEFT adapters
        yield "🧠 Инициализация тренера..."
        # 6. Trainer
        # NOTE(review): passing dataset_text_field / tokenizer / max_seq_length
        # directly to SFTTrainer matches the older trl API; newer trl moved
        # these into SFTConfig — confirm the pinned trl version.
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            dataset_text_field="text",
            peft_config=peft_config,
            tokenizer=tokenizer,
            max_seq_length=1024
        )
        yield "🔥 ОБУЧЕНИЕ ЗАПУЩЕНО! Ждите завершения..."
        trainer.train()
        yield "💾 Сохранение..."
        # Save only the adapter weights plus the tokenizer files.
        trainer.model.save_pretrained(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)
        yield f"✅ УСПЕХ! Адаптер в папке '{OUTPUT_DIR}'. Можно чатиться."
    except Exception as e:
        import traceback
        # Surface the full traceback in the UI log instead of crashing gradio.
        yield f"❌ ОШИБКА:\n{traceback.format_exc()}"
# ==========================================
# PART 3: CHAT
# ==========================================
def load_chat_model():
    """Lazily initialize the global chat model and tokenizer.

    Prefers the fine-tuned LoRA adapter in OUTPUT_DIR when one exists,
    otherwise falls back to the plain base model. Returns a status string
    for display in the UI.
    """
    global chat_model, chat_tokenizer
    # Already initialized — nothing to do.
    if chat_model is not None:
        return "Уже загружено"
    try:
        # Slow (pure-Python) tokenizer here as well, mirroring training.
        chat_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, use_fast=False, trust_remote_code=True
        )
        adapter_cfg = os.path.join(OUTPUT_DIR, "adapter_config.json")
        if not os.path.exists(adapter_cfg):
            # No adapter trained yet: serve the untuned base model.
            chat_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
            return "⚠️ Адаптер не найден. Работает чистая модель."
        base = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
        chat_model = PeftModel.from_pretrained(base, OUTPUT_DIR)
        return f"✅ Адаптер QLoRA загружен!"
    except Exception as e:
        return f"Ошибка: {e}"
def generate_answer(prompt, history):
    """Chat callback: answer `prompt` and return the updated chat history.

    Loads the model lazily on first use. FIX: this function is wired as the
    output of a gr.Chatbot component, which expects a list of
    (user, bot) message pairs — the original returned a bare string, which
    the Chatbot cannot render. We now append to `history` and return it.

    Args:
        prompt: user's question from the textbox.
        history: current gr.Chatbot history (list of pairs) or None.

    Returns:
        Updated history list of (user, bot) pairs.
    """
    history = history or []
    if chat_model is None:  # idiomatic None check (was truthiness test)
        status = load_chat_model()
        if "Ошибка" in status:
            # Surface the load error as a chat reply instead of breaking the UI.
            history.append((prompt, status))
            return history
    formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
    inputs = chat_tokenizer(formatted_prompt, return_tensors="pt")
    outputs = chat_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.6,
        top_p=0.95
    )
    response = chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the last "### Response:" marker (the decoded
    # output echoes the prompt).
    if "### Response:" in response:
        response = response.split("### Response:")[-1].strip()
    history.append((prompt, response))
    return history
# ==========================================
# UI
# ==========================================
with gr.Blocks(title="MandreAI Fix") as app:
    gr.Markdown("# 🦎 MandreAI 1B (CPU Fix)")
    with gr.Tabs():
        # Tab 1: turn raw source files into a training JSON dataset.
        with gr.Tab("1. Датасет"):
            src_files = gr.File(file_count="multiple", label="Исходные файлы")
            make_btn = gr.Button("Создать JSON", variant="primary")
            dataset_file = gr.File(label="Готовый датасет")
            gen_status = gr.Textbox(label="Статус")
            make_btn.click(
                generate_json_dataset,
                inputs=[src_files],
                outputs=[dataset_file, gen_status],
            )
        # Tab 2: run LoRA fine-tuning; log lines stream into the textbox.
        with gr.Tab("2. Обучение"):
            with gr.Row():
                dataset_in = gr.File(label="train_data.json")
                epoch_count = gr.Number(value=3, label="Эпохи", precision=0)
                learn_rate = gr.Number(value=2e-4, label="LR")
            train_btn = gr.Button("ЗАПУСТИТЬ ОБУЧЕНИЕ", variant="stop")
            train_log = gr.Textbox(label="Лог", lines=10)
            train_btn.click(
                train_mandre_ai,
                inputs=[dataset_in, epoch_count, learn_rate],
                outputs=[train_log],
            )
        # Tab 3: chat with the (optionally fine-tuned) model.
        with gr.Tab("3. Чат"):
            chat_box = gr.Chatbot(label="MandreAI")
            user_msg = gr.Textbox(label="Вопрос")
            send_btn = gr.Button("Отправить")
            send_btn.click(generate_answer, [user_msg, chat_box], chat_box)

if __name__ == "__main__":
    app.queue().launch(allowed_paths=["."])