import gradio as gr
import json
import os
import torch
import pandas as pd
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from trl import SFTTrainer
from datasets import Dataset

# --- CONFIGURATION ---
MODEL_ID = "Maincode/Maincoder-1B"
OUTPUT_DIR = "mandre_qlora_adapter"
JSON_FILE_NAME = "train_data.json"

# Global variables for the chat model
chat_model = None
chat_tokenizer = None

# ==========================================
# PART 1: DATASET GENERATOR
# ==========================================

def generate_json_dataset(files):
    if not files:
        return None, "❌ Ошибка: Вы не загрузили файлы."

    data_entries = []
    
    for file_item in files:
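        # gr.File may yield plain path strings or tempfile-like objects with a .name attribute, depending on the Gradio version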
        if isinstance(file_item, str):
            file_path = file_item
        elif hasattr(file_item, 'name'):
            file_path = file_item.name
        else:
            continue

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            print(f"Skipping file {file_path}: {e}")
            continue 
            
        filename = os.path.basename(file_path)
        
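        # Wrap each file in an Alpaca-style Instruction/Response prompt; <|endoftext|> marks the end of a training sample.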
        instruction = f"Analyze the code/text in file '{filename}' from the MandreLib project."
        text = f"### Instruction:\n{instruction}\n\n### Response:\n{content}<|endoftext|>"
        
        data_entries.append({"text": text})

    if not data_entries:
        return None, "❌ Не удалось прочитать ни один текстовый файл."

    try:
        with open(JSON_FILE_NAME, 'w', encoding='utf-8') as f:
            json.dump(data_entries, f, indent=4, ensure_ascii=False)
        abs_path = os.path.abspath(JSON_FILE_NAME)
        return abs_path, f"✅ Готово! Обработано файлов: {len(data_entries)}. Файл {JSON_FILE_NAME} создан."
    except Exception as e:
        return None, f"❌ Ошибка записи JSON: {e}"

# ==========================================
# PART 2: TRAINING (FIXED)
# ==========================================

def train_mandre_ai(file_obj, epochs, lr):
    if file_obj is None:
        if os.path.exists(JSON_FILE_NAME):
            json_path = JSON_FILE_NAME
            yield f"⚠️ Файл не передан, используем {JSON_FILE_NAME} из прошлой генерации."
        else:
            yield "❌ Ошибка: Нет файла с данными!"
            return
    else:
        json_path = file_obj.name if hasattr(file_obj, 'name') else file_obj

    yield f"🚀 Старт обучения {MODEL_ID}..."
    
    try:
        # 1. Load the data
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        dataset = Dataset.from_pandas(pd.DataFrame(data))
        
        yield f"📊 Данные: {len(dataset)} строк. Загрузка токенизатора..."

        # 2. Tokenizer (FIX: use_fast=False to avoid the Rust tokenizer error)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
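        # the base model likely has no dedicated pad token, so reuse EOS for padding (a common default)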
        tokenizer.pad_token = tokenizer.eos_token

        # 3. LoRA Config
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj']
        )
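        # assumes Maincoder-1B exposes LLaMA-style attention projections (q/k/v/o_proj); adjust target_modules if the architecture names differ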

        # 4. Training arguments
        training_args = TrainingArguments(
            output_dir=OUTPUT_DIR,
            num_train_epochs=float(epochs),
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=float(lr),
            weight_decay=0.01,
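            # force a CPU-only run; no_cuda is the older (now deprecated) alias of use_cpu in recent transformers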
            use_cpu=True,           
            no_cuda=True,           
            fp16=False,             
            logging_steps=1,
            save_total_limit=1,
            push_to_hub=False,
            report_to="none"
        )
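        # effective batch size = per_device_train_batch_size * gradient_accumulation_steps = 8 samples per optimizer step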

        yield "📥 Загрузка модели (Maincoder-1B)..."
        
        # 5. Load the model
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
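        # note: no BitsAndBytesConfig is applied, so despite the "qlora" output name this is plain (non-quantized) LoRA on CPU
        # gradient checkpointing trades compute for memory; enable_input_require_grads keeps gradients flowing to the LoRA layers while the base weights stay frozen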
        model.gradient_checkpointing_enable() 
        model.enable_input_require_grads()

        yield "🧠 Инициализация тренера..."

        # 6. Trainer
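        # NOTE: passing dataset_text_field / tokenizer / max_seq_length directly to SFTTrainer follows the older trl API;
        # newer trl releases expect these inside an SFTConfig instead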
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            dataset_text_field="text",
            peft_config=peft_config,
            tokenizer=tokenizer,
            max_seq_length=1024
        )

        yield "🔥 ОБУЧЕНИЕ ЗАПУЩЕНО! Ждите завершения..."
        
        trainer.train()
        
        yield "💾 Сохранение..."
        trainer.model.save_pretrained(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)
        
        yield f"✅ УСПЕХ! Адаптер в папке '{OUTPUT_DIR}'. Можно чатиться."

    except Exception as e:
        import traceback
        yield f"❌ ОШИБКА:\n{traceback.format_exc()}"

# ==========================================
# PART 3: CHAT
# ==========================================

def load_chat_model():
    global chat_model, chat_tokenizer
    if chat_model is not None: return "Already loaded"

    try:
        # FIX: use_fast=False here as well
        chat_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
        
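        # if a trained adapter exists on disk, load the base model and attach the LoRA weights on top of it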
        if os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
            base = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
            chat_model = PeftModel.from_pretrained(base, OUTPUT_DIR)
            return f"✅ Адаптер QLoRA загружен!"
        else:
            chat_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
            return "⚠️ Адаптер не найден. Работает чистая модель."
            
    except Exception as e:
        return f"Ошибка: {e}"

def generate_answer(prompt, history):
    history = history or []
    if chat_model is None:
        status = load_chat_model()
        if "Error" in status:
            history.append((prompt, status))
            return history

    # Reuse the same Instruction/Response template the training data was built with.
    formatted_prompt = f"### Instruction:\n{prompt}\n\n### Response:\n"
    inputs = chat_tokenizer(formatted_prompt, return_tensors="pt")

    outputs = chat_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.6,
        top_p=0.95
    )

    response = chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "### Response:" in response:
        response = response.split("### Response:")[-1].strip()

    # gr.Chatbot expects the full history as a list of (user, bot) pairs.
    history.append((prompt, response))
    return history

# ==========================================
# INTERFACE
# ==========================================

with gr.Blocks(title="MandreAI Fix") as demo:
    gr.Markdown("# 🦎 MandreAI 1B (CPU Fix)")
    
    with gr.Tabs():
        with gr.Tab("1. Датасет"):
            files_input = gr.File(file_count="multiple", label="Исходные файлы")
            btn_gen = gr.Button("Создать JSON", variant="primary")
            json_output = gr.File(label="Готовый датасет")
            status_output = gr.Textbox(label="Статус")
            btn_gen.click(generate_json_dataset, inputs=[files_input], outputs=[json_output, status_output])

        with gr.Tab("2. Обучение"):
            with gr.Row():
                train_file_input = gr.File(label="train_data.json")
                epochs = gr.Number(value=3, label="Epochs", precision=0)
                lr = gr.Number(value=2e-4, label="LR")
            
            btn_train = gr.Button("START TRAINING", variant="stop")
            log_output = gr.Textbox(label="Log", lines=10)
            
            btn_train.click(train_mandre_ai, inputs=[train_file_input, epochs, lr], outputs=[log_output])

        with gr.Tab("3. Чат"):
            chatbot = gr.Chatbot(label="MandreAI")
            msg_input = gr.Textbox(label="Question")
            btn_send = gr.Button("Send")
            btn_send.click(generate_answer, [msg_input, chatbot], chatbot)

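# allowed_paths=["."] lets Gradio serve the generated train_data.json from the working directory as a downloadable file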
if __name__ == "__main__":
    demo.queue().launch(allowed_paths=["."])