# PUBTEST / app.py
# sterepando's picture
# Update app.py
# bbdc54d verified
import gradio as gr
import json
import os
import torch
import pandas as pd
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from trl import SFTTrainer
from datasets import Dataset
# --- CONFIGURATION ---
MODEL_ID = "Maincode/Maincoder-1B"        # base model to fine-tune and chat with
OUTPUT_DIR = "mandre_qlora_adapter"        # where the trained LoRA adapter is saved
JSON_FILE_NAME = "train_data.json"         # dataset file produced by tab 1, consumed by tab 2
# Global state for the chat tab, lazily initialized by load_chat_model().
chat_model = None
chat_tokenizer = None
# ==========================================
# PART 1: DATASET GENERATOR
# ==========================================
def generate_json_dataset(files):
    """Build an instruction-tuning JSON dataset from uploaded files.

    Each readable file becomes one record of the form
    ``{"text": "### Instruction:...### Response:<file content><|endoftext|>"}``.
    The list is written to ``JSON_FILE_NAME`` in the working directory.

    Args:
        files: list of file paths (str) or Gradio file objects exposing
            a ``.name`` attribute.

    Returns:
        Tuple ``(abs_path, status_message)`` on success, or
        ``(None, error_message)`` when nothing could be read or written.
    """
    if not files:
        return None, "❌ Ошибка: Вы не загрузили файлы."
    data_entries = []
    for file_item in files:
        # Gradio may hand us plain path strings or tempfile wrappers.
        if isinstance(file_item, str):
            file_path = file_item
        elif hasattr(file_item, 'name'):
            file_path = file_item.name
        else:
            continue
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            # Best-effort: skip unreadable/binary files instead of aborting the batch.
            print(f"Skipping file {file_path}: {e}")
            continue
        filename = os.path.basename(file_path)
        # BUG FIX: the instruction previously hard-coded '(unknown)' and never
        # used the computed filename; interpolate the real basename instead.
        instruction = f"Analyze the code/text in file '{filename}' from the MandreLib project."
        text = f"### Instruction:\n{instruction}\n\n### Response:\n{content}<|endoftext|>"
        data_entries.append({"text": text})
    if not data_entries:
        return None, "❌ Не удалось прочитать ни один текстовый файл."
    try:
        with open(JSON_FILE_NAME, 'w', encoding='utf-8') as f:
            json.dump(data_entries, f, indent=4, ensure_ascii=False)
        abs_path = os.path.abspath(JSON_FILE_NAME)
        return abs_path, f"✅ Готово! Обработано файлов: {len(data_entries)}. Файл {JSON_FILE_NAME} создан."
    except Exception as e:
        return None, f"❌ Ошибка записи JSON: {e}"
# ==========================================
# PART 2: TRAINING (FIXED)
# ==========================================
def train_mandre_ai(file_obj, epochs, lr):
    """Fine-tune MODEL_ID with a LoRA adapter on CPU; yields status strings.

    Generator consumed by the Gradio log textbox: each ``yield`` updates the
    visible log line. On any failure the full traceback is yielded instead of
    raised, so the UI shows the error rather than hanging.

    Args:
        file_obj: uploaded dataset file (Gradio file object or path string),
            or None to reuse JSON_FILE_NAME from a previous generation step.
        epochs: number of training epochs (cast to float for TrainingArguments).
        lr: learning rate (cast to float).
    """
    if file_obj is None:
        if os.path.exists(JSON_FILE_NAME):
            json_path = JSON_FILE_NAME
            yield f"⚠️ Файл не передан, используем {JSON_FILE_NAME} из прошлой генерации."
        else:
            yield "❌ Ошибка: Нет файла с данными!"
            return
    else:
        # Gradio may pass a tempfile wrapper (.name) or a plain path string.
        json_path = file_obj.name if hasattr(file_obj, 'name') else file_obj
    yield f"🚀 Старт обучения {MODEL_ID}..."
    try:
        # 1. Load the data: a JSON list of {"text": ...} records.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        dataset = Dataset.from_pandas(pd.DataFrame(data))
        yield f"📊 Данные: {len(dataset)} строк. Загрузка токенизатора..."
        # 2. Tokenizer (FIX: use_fast=False to avoid the Rust tokenizer error).
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token  # model has no pad token by default
        # 3. LoRA config: low-rank adapters on the attention projections only.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj']
        )
        # 4. Training arguments — forced to CPU, no fp16 (Space has no GPU).
        training_args = TrainingArguments(
            output_dir=OUTPUT_DIR,
            num_train_epochs=float(epochs),
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,  # effective batch size 2 * 4 = 8
            learning_rate=float(lr),
            weight_decay=0.01,
            use_cpu=True,
            no_cuda=True,  # NOTE(review): older alias of use_cpu — redundant, kept for compatibility
            fp16=False,
            logging_steps=1,
            save_total_limit=1,
            push_to_hub=False,
            report_to="none"
        )
        yield "📥 Загрузка модели (Maincoder-1B)..."
        # 5. Load the base model; gradient checkpointing trades compute for RAM.
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()  # needed so LoRA grads flow with checkpointing
        yield "🧠 Инициализация тренера..."
        # 6. Trainer — SFTTrainer applies peft_config and tokenizes the "text" field.
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            dataset_text_field="text",
            peft_config=peft_config,
            tokenizer=tokenizer,
            max_seq_length=1024
        )
        yield "🔥 ОБУЧЕНИЕ ЗАПУЩЕНО! Ждите завершения..."
        trainer.train()
        yield "💾 Сохранение..."
        # Saves only the adapter weights (small), not the full base model.
        trainer.model.save_pretrained(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)
        yield f"✅ УСПЕХ! Адаптер в папке '{OUTPUT_DIR}'. Можно чатиться."
    except Exception as e:
        import traceback
        yield f"❌ ОШИБКА:\n{traceback.format_exc()}"
# ==========================================
# PART 3: CHAT
# ==========================================
def load_chat_model():
    """Lazily populate the module-level chat model and tokenizer.

    Prefers the trained LoRA adapter in OUTPUT_DIR when its
    adapter_config.json exists; otherwise falls back to the plain base
    model. Returns a human-readable status string for the UI.
    """
    global chat_model, chat_tokenizer
    if chat_model is not None:
        return "Уже загружено"
    try:
        # FIX: slow (Python) tokenizer here too, to dodge the Rust tokenizer error.
        chat_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, use_fast=False, trust_remote_code=True
        )
        adapter_marker = os.path.join(OUTPUT_DIR, "adapter_config.json")
        if not os.path.exists(adapter_marker):
            # No trained adapter yet — serve the untouched base model.
            chat_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
            return "⚠️ Адаптер не найден. Работает чистая модель."
        base = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
        chat_model = PeftModel.from_pretrained(base, OUTPUT_DIR)
        return f"✅ Адаптер QLoRA загружен!"
    except Exception as e:
        return f"Ошибка: {e}"
def generate_answer(prompt, history):
    """Generate a chat reply and return the updated Chatbot history.

    Args:
        prompt: user question from the textbox.
        history: current list of [user, assistant] message pairs (may be None).

    Returns:
        The history extended with the new [prompt, answer] pair — the format
        gr.Chatbot expects. BUG FIX: the original returned a bare string,
        which clobbered previous turns and rendered incorrectly in the
        Chatbot component it is wired to.
    """
    history = list(history or [])
    # BUG FIX: use an explicit None check instead of truthiness on a model object.
    if chat_model is None:
        status = load_chat_model()
        if "Ошибка" in status:
            history.append([prompt, status])
            return history
    formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
    inputs = chat_tokenizer(formatted_prompt, return_tensors="pt")
    # Inference only — disable autograd to save memory/time on CPU.
    with torch.no_grad():
        outputs = chat_model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.6,
            top_p=0.95
        )
    response = chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only the part after the marker.
    if "### Response:" in response:
        response = response.split("### Response:")[-1].strip()
    history.append([prompt, response])
    return history
# ==========================================
# INTERFACE
# ==========================================
with gr.Blocks(title="MandreAI Fix") as demo:
    gr.Markdown("# 🦎 MandreAI 1B (CPU Fix)")
    with gr.Tabs():
        # Tab 1: turn uploaded source files into a JSON instruction dataset.
        with gr.Tab("1. Датасет"):
            files_input = gr.File(file_count="multiple", label="Исходные файлы")
            btn_gen = gr.Button("Создать JSON", variant="primary")
            json_output = gr.File(label="Готовый датасет")
            status_output = gr.Textbox(label="Статус")
            btn_gen.click(generate_json_dataset, inputs=[files_input], outputs=[json_output, status_output])
        # Tab 2: run LoRA fine-tuning; train_mandre_ai streams log lines into the textbox.
        with gr.Tab("2. Обучение"):
            with gr.Row():
                train_file_input = gr.File(label="train_data.json")
                epochs = gr.Number(value=3, label="Эпохи", precision=0)
                lr = gr.Number(value=2e-4, label="LR")
            btn_train = gr.Button("ЗАПУСТИТЬ ОБУЧЕНИЕ", variant="stop")
            log_output = gr.Textbox(label="Лог", lines=10)
            btn_train.click(train_mandre_ai, inputs=[train_file_input, epochs, lr], outputs=[log_output])
        # Tab 3: chat with the (optionally adapter-augmented) model.
        with gr.Tab("3. Чат"):
            chatbot = gr.Chatbot(label="MandreAI")
            msg_input = gr.Textbox(label="Вопрос")
            btn_send = gr.Button("Отправить")
            btn_send.click(generate_answer, [msg_input, chatbot], chatbot)
# queue() is required so the training generator can stream incremental updates.
if __name__ == "__main__":
    demo.queue().launch(allowed_paths=["."])