import gradio as gr
import json
import os
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
)
from peft import LoraConfig, TaskType, PeftModel
from trl import SFTTrainer
from datasets import Dataset

# --- CONFIGURATION ---
MODEL_ID = "Maincode/Maincoder-1B"
OUTPUT_DIR = "mandre_qlora_adapter"
JSON_FILE_NAME = "train_data.json"

# Global variables for the chat tab
chat_model = None
chat_tokenizer = None


# ==========================================
# PART 1: DATASET GENERATOR
# ==========================================
def generate_json_dataset(files):
    if not files:
        return None, "❌ Error: you did not upload any files."

    data_entries = []
    for file_item in files:
        # gr.File may hand back plain paths or file-like objects with a .name
        if isinstance(file_item, str):
            file_path = file_item
        elif hasattr(file_item, 'name'):
            file_path = file_item.name
        else:
            continue

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            print(f"Skipping file {file_path}: {e}")
            continue

        filename = os.path.basename(file_path)
        instruction = f"Analyze the code/text in file '{filename}' from the MandreLib project."
        text = f"### Instruction:\n{instruction}\n\n### Response:\n{content}<|endoftext|>"
        data_entries.append({"text": text})

    if not data_entries:
        return None, "❌ Could not read a single text file."

    try:
        with open(JSON_FILE_NAME, 'w', encoding='utf-8') as f:
            json.dump(data_entries, f, indent=4, ensure_ascii=False)
        abs_path = os.path.abspath(JSON_FILE_NAME)
        return abs_path, f"✅ Done! Files processed: {len(data_entries)}. {JSON_FILE_NAME} created."
    except Exception as e:
        return None, f"❌ JSON write error: {e}"


# ==========================================
# PART 2: TRAINING (FIXED)
# ==========================================
def train_mandre_ai(file_obj, epochs, lr):
    if file_obj is None:
        if os.path.exists(JSON_FILE_NAME):
            json_path = JSON_FILE_NAME
            yield f"⚠️ No file provided, using {JSON_FILE_NAME} from the previous generation step."
        else:
            yield "❌ Error: no data file found!"
            return
    else:
        json_path = file_obj.name if hasattr(file_obj, 'name') else file_obj

    yield f"🚀 Starting training of {MODEL_ID}..."
    try:
        # 1. Load the data
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        dataset = Dataset.from_pandas(pd.DataFrame(data))
        yield f"📊 Data: {len(dataset)} rows. Loading tokenizer..."

        # 2. Tokenizer (FIX: use_fast=False to avoid the Rust tokenizer error)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

        # 3. LoRA config
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj']
        )

        # 4. Training arguments (CPU-only run, no mixed precision)
        training_args = TrainingArguments(
            output_dir=OUTPUT_DIR,
            num_train_epochs=float(epochs),
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=float(lr),
            weight_decay=0.01,
            use_cpu=True,
            fp16=False,
            logging_steps=1,
            save_total_limit=1,
            push_to_hub=False,
            report_to="none"
        )

        yield "📥 Loading model (Maincoder-1B)..."
        # 5. Load the model
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()

        yield "🧠 Initializing trainer..."
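        # NOTE (assumption): the SFTTrainer call below uses the keyword API of
        # older trl releases (roughly < 0.9), where dataset_text_field,
        # max_seq_length and tokenizer are passed to the trainer directly.
        # Newer trl versions move dataset_text_field / max_seq_length into
        # SFTConfig and rename tokenizer to processing_class, so adjust these
        # arguments if you are on a recent release.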
        # 6. Trainer
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            dataset_text_field="text",
            peft_config=peft_config,
            tokenizer=tokenizer,
            max_seq_length=1024
        )

        yield "🔥 TRAINING STARTED! Wait for it to finish..."
        trainer.train()

        yield "💾 Saving..."
        trainer.model.save_pretrained(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)
        yield f"✅ SUCCESS! Adapter saved to '{OUTPUT_DIR}'. You can chat now."
    except Exception:
        import traceback
        yield f"❌ ERROR:\n{traceback.format_exc()}"


# ==========================================
# PART 3: CHAT
# ==========================================
def load_chat_model():
    global chat_model, chat_tokenizer
    if chat_model is not None:
        return "Already loaded"
    try:
        # FIX: use_fast=False here as well
        chat_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
        if os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
            base = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
            chat_model = PeftModel.from_pretrained(base, OUTPUT_DIR)
            return "✅ QLoRA adapter loaded!"
        else:
            chat_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
            return "⚠️ Adapter not found. Running the plain base model."
    except Exception as e:
        return f"Error: {e}"


def generate_answer(prompt, history):
    # gr.Chatbot expects a list of (user, bot) message pairs, so append to the
    # history instead of returning a bare string.
    history = history or []
    if chat_model is None:
        status = load_chat_model()
        if "Error" in status:
            history.append((prompt, status))
            return history

    formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
    inputs = chat_tokenizer(formatted_prompt, return_tensors="pt")
    outputs = chat_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.6,
        top_p=0.95
    )
    response = chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "### Response:" in response:
        response = response.split("### Response:")[-1].strip()

    history.append((prompt, response))
    return history


# ==========================================
# INTERFACE
# ==========================================
with gr.Blocks(title="MandreAI Fix") as demo:
    gr.Markdown("# 🦎 MandreAI 1B (CPU Fix)")
    with gr.Tabs():
        with gr.Tab("1. Dataset"):
            files_input = gr.File(file_count="multiple", label="Source files")
            btn_gen = gr.Button("Create JSON", variant="primary")
            json_output = gr.File(label="Generated dataset")
            status_output = gr.Textbox(label="Status")
            btn_gen.click(generate_json_dataset, inputs=[files_input], outputs=[json_output, status_output])
        with gr.Tab("2. Training"):
            with gr.Row():
                train_file_input = gr.File(label="train_data.json")
                epochs = gr.Number(value=3, label="Epochs", precision=0)
                lr = gr.Number(value=2e-4, label="LR")
            btn_train = gr.Button("START TRAINING", variant="stop")
            log_output = gr.Textbox(label="Log", lines=10)
            btn_train.click(train_mandre_ai, inputs=[train_file_input, epochs, lr], outputs=[log_output])
        with gr.Tab("3. Chat"):
            chatbot = gr.Chatbot(label="MandreAI")
            msg_input = gr.Textbox(label="Question")
            btn_send = gr.Button("Send")
            btn_send.click(generate_answer, [msg_input, chatbot], chatbot)

if __name__ == "__main__":
    demo.queue().launch(allowed_paths=["."])
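
# --- OPTIONAL: MERGE THE ADAPTER (sketch) ---
# A minimal sketch, assuming the LoRA adapter in OUTPUT_DIR was trained above;
# it is not called anywhere in this app. The function name
# merge_adapter_to_full_model and the "mandre_merged_model" directory are
# illustrative choices, not part of the original script. Merging bakes the
# adapter weights into the base model so it can be served without peft.
def merge_adapter_to_full_model(save_dir: str = "mandre_merged_model") -> str:
    base = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
    merged = PeftModel.from_pretrained(base, OUTPUT_DIR).merge_and_unload()
    merged.save_pretrained(save_dir)
    AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True).save_pretrained(save_dir)
    return os.path.abspath(save_dir)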