File size: 4,083 Bytes
bafda90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import json

# ============================================================
# ⚙️ GLOBAL CONFIGURATION
# ============================================================
BASE_MODEL = "bigcode/santacoder"  # Hugging Face hub id of the base causal-LM to fine-tune
LORA_PATH = "./lora_output"        # directory where the trained LoRA adapter is saved
DATASET_FILE = "codesearchnet_lora_dataset.json"   # local cache of the pre-processed dataset
MAX_TOKEN_LENGTH = 256             # max sequence length used at tokenization time
NUM_SAMPLES_TO_PROCESS = 5000      # number of CodeSearchNet samples to download/format
DEFAULT_EPOCHS = 5 # <--- deep training! (default epoch count for the Trainer)

# Global variables (populated once by setup_resources(); acts as a module-level
# singleton so the model/tokenizer/dataset are loaded only one time)
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None

# ============================================================
# 🚨 LÓGICA DE PRE-PROCESAMIENTO DE DATOS (INTEGRADA) 🚨
# ============================================================
def prepare_codesearchnet():
    """Download, format, and cache CodeSearchNet as a prompt/completion JSON file.

    Does nothing if ``DATASET_FILE`` already exists. Otherwise downloads the
    first ``NUM_SAMPLES_TO_PROCESS`` samples of the Python split, reformats
    each example into ``{"prompt", "completion"}`` pairs for LoRA training,
    and writes them to ``DATASET_FILE``. On any failure, writes a minimal
    placeholder dataset so downstream ``load_dataset`` calls do not crash.
    """
    if os.path.exists(DATASET_FILE):
        print(f"✅ Dataset '{DATASET_FILE}' ya existe. Cargando directamente.")
        return

    print(f"🔄 Dataset no encontrado. Iniciando descarga y pre-procesamiento de CodeSearchNet ({NUM_SAMPLES_TO_PROCESS} muestras)...")

    try:
        # Passing a sliced `split=` returns a single `Dataset` (NOT a DatasetDict).
        raw_csn = load_dataset('Nan-Do/code-search-net-python', split=f'train[:{NUM_SAMPLES_TO_PROCESS}]')

        def format_for_lora(example):
            # Instruction-style prompt; the model learns to complete the function.
            prompt_text = (
                f"# Descripción: {example['docstring_summary']}\n"
                f"# Completa la siguiente función:\n"
                f"def {example['func_name']}("
            )
            return {
                "prompt": prompt_text,
                "completion": example['code'],
            }

        # BUG FIX: `raw_csn` is already the selected split, so it has no
        # "train" key — `raw_csn["train"].column_names` raised at runtime.
        # Use the Dataset's own column_names to drop the original columns.
        lora_dataset = raw_csn.map(
            format_for_lora,
            batched=False,
            remove_columns=raw_csn.column_names,
        )

        lora_dataset.to_json(DATASET_FILE)
        print(f"✅ Pre-procesamiento completado. {NUM_SAMPLES_TO_PROCESS} ejemplos guardados en '{DATASET_FILE}'.")

    except Exception as e:
        # Best-effort fallback: write a tiny valid dataset so the rest of the
        # app can still start instead of crashing on a missing file.
        print(f"❌ Error CRÍTICO al descargar/procesar CodeSearchNet. Error: {e}")
        minimal_dataset = [{"prompt": "# Error de carga. Intenta de nuevo.", "completion": "pass\n"}] * 10
        with open(DATASET_FILE, 'w') as f:
            json.dump(minimal_dataset, f)

# ============================================================
# 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON)
# ============================================================

def setup_resources():
    """Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
    global tokenizer, lora_model, tokenized_dataset

    # 🛑 1. PREPARA EL DATASET DE CODESEARCHNET ANTES DE INTENTAR CARGARLO
    prepare_codesearchnet()
    
    # 2. Autenticación con Hugging Face
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # 3. Carga del Tokenizer y Modelo Base
    print("\n🔄 Cargando modelo y tokenizer una sola vez...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto") 

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4. Configuración y Aplicación LoRA (PEFT)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["c_proj", "c_attn"], 
    )
    lora_model = get_peft_model(base_model, peft_config)
    print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}")

    # 5. Carga y Tokenización del Dataset
    print(f"📚 Cargando y tokenizando dataset de: {DATASET_FILE}...")
    try:
        raw_dataset = load_dataset("json", data_files=DATASET_FILE)