File size: 6,864 Bytes
0aef10f
719f624
f76d825
719f624
69630f9
 
12e3c33
 
 
 
69630f9
12e3c33
3fe5c2e
 
69630f9
 
 
 
12e3c33
3fe5c2e
 
b8b84b7
 
 
 
 
 
 
 
 
 
 
69630f9
12e3c33
69630f9
 
 
 
12e3c33
 
69630f9
 
 
 
 
eca2f3b
12e3c33
69630f9
 
 
 
 
0aef10f
12e3c33
69630f9
719f624
0aef10f
69630f9
12e3c33
0aef10f
719f624
 
 
 
 
 
eca2f3b
0aef10f
12e3c33
69630f9
b8b84b7
69630f9
 
b8b84b7
0aef10f
 
b8b84b7
0aef10f
69630f9
 
719f624
f76d825
eca2f3b
 
0aef10f
b8b84b7
69630f9
719f624
69630f9
12e3c33
 
 
b8b84b7
69630f9
 
f76d825
12e3c33
 
3fe5c2e
f76d825
 
69630f9
 
 
 
719f624
 
f76d825
719f624
 
 
12e3c33
 
69630f9
 
12e3c33
69630f9
 
12e3c33
3fe5c2e
719f624
 
 
69630f9
52206e3
69630f9
b8b84b7
69630f9
b8b84b7
69630f9
b8b84b7
 
 
 
 
 
 
719f624
69630f9
719f624
3fe5c2e
69630f9
eca2f3b
719f624
f76d825
0aef10f
f76d825
0aef10f
 
719f624
69630f9
b8b84b7
 
 
69630f9
b8b84b7
 
 
69630f9
b8b84b7
69630f9
b8b84b7
 
 
f76d825
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os, shutil, subprocess, zipfile
from pathlib import Path
import gradio as gr

# Workspace layout: every artifact lives next to this script.
# NOTE: the dunder name is `__file__`; a bare `_file_` is undefined and would
# raise NameError at import time.
ROOT   = Path(__file__).resolve().parent
DATA   = ROOT / "dataset.jsonl"             # single-file mode target
LOG    = ROOT / "train.log"                 # combined stdout/stderr of the trainer
OUT    = ROOT / "trained_model"             # directory the trainer saves into
ZIP    = ROOT / "trained_model.zip"         # archive offered for download

# ---------- helpers ----------
def ls_workspace() -> str:
    """Render a plain-text listing of the workspace directory (ROOT).

    Entries that are not regular files (i.e. directories) are listed first,
    then files; each group is ordered by case-insensitive name. Every row is
    tab-separated: a ``[DIR]`` marker (five spaces for non-directories), the
    size right-aligned to ten columns, and the entry name. If an entry's size
    cannot be read it is shown as 0.

    Returns:
        The rows joined by newlines, or ``"(empty)"`` when ROOT has no entries.
    """
    entries = list(ROOT.iterdir())
    entries.sort(key=lambda entry: (entry.is_file(), entry.name.lower()))

    lines = []
    for entry in entries:
        try:
            nbytes = entry.stat().st_size
        except Exception:
            # Best effort: an unreadable entry is still listed, with size 0.
            nbytes = 0
        marker = "[DIR]" if entry.is_dir() else "     "
        lines.append(f"{marker}\t{nbytes:>10}\t{entry.name}")

    listing = "\n".join(lines)
    return listing if listing else "(empty)"

def list_models():
    """Return the model directories available in the workspace, sorted.

    A directory directly under ROOT counts as a model when it contains
    ``config.json`` plus at least one of ``tokenizer.json`` or
    ``tokenizer_config.json``. The training output directory OUT is included
    whenever it exists, even if it does not pass that check.

    Returns:
        Sorted list of directory paths (as strings) without duplicates.
    """
    tokenizer_files = ("tokenizer.json", "tokenizer_config.json")

    def looks_like_model(folder):
        # Same short-circuit order as a chained ``and``: dir, config, tokenizer.
        if not folder.is_dir():
            return False
        if not (folder / "config.json").exists():
            return False
        return any((folder / name).exists() for name in tokenizer_files)

    found = [str(child) for child in ROOT.iterdir() if looks_like_model(child)]

    trained = str(OUT)
    if OUT.exists() and trained not in found:
        found.append(trained)

    found.sort()
    return found

# ---------- train tab ----------
def upload_dataset(file):
    """Copy an uploaded dataset file to the fixed ``dataset.jsonl`` location.

    Only single files are handled here; an uploaded folder is not moved.

    Args:
        file: Value delivered by the upload component. Either a filesystem
            path (``str`` / ``os.PathLike``) or an object exposing the path
            as ``.name``. Any falsy value means nothing was selected.

    Returns:
        Tuple ``(status message, workspace listing)``.
    """
    if not file:
        return "❌ No file selected.", ls_workspace()

    # Depending on the Gradio version / component ``type`` the callback gets
    # either a plain path string or a tempfile-like wrapper whose ``.name``
    # holds the path; accept both instead of rejecting path strings.
    if isinstance(file, (str, os.PathLike)):
        src = file
    else:
        src = getattr(file, "name", None)

    if src and os.path.isfile(src):
        try:
            shutil.copy(src, DATA)
        except OSError as exc:
            # Report the failure in the status box instead of crashing the callback.
            return f"❌ Upload failed: {exc}", ls_workspace()
        return f"✅ Uploaded → {DATA.name}", ls_workspace()
    return "⚠ Unexpected item; please upload a .jsonl file.", ls_workspace()

def start_training():
    """Run ``train.py`` to completion and report the outcome to the UI.

    Removes the previous output directory and zip, resets ``train.log``,
    then runs the trainer as a blocking subprocess with its stdout and stderr
    appended to the log. Hyperparameters are fixed in the command line below.

    Returns:
        5-tuple matching the click handler's outputs: status text, update for
        the download file component, workspace listing, log tail, and update
        for the model dropdown.
    """
    import sys  # local import: only needed to locate the running interpreter

    # Clean previous artifacts
    if OUT.exists():
        shutil.rmtree(OUT, ignore_errors=True)
    ZIP.unlink(missing_ok=True)
    LOG.write_text("🔥 Training started…\n", encoding="utf-8")

    # Use the interpreter running this app so the trainer sees the same
    # installed packages; a bare "python" may resolve elsewhere or not at all.
    cmd = [
        sys.executable or "python", str(ROOT / "train.py"),
        "--dataset", str(DATA),                  # single-file mode; folder mode would need a UI extension
        "--output",  str(OUT),
        "--zip_path", str(ZIP),
        "--model_name", "Salesforce/codegen-350M-multi",
        "--epochs", "1",
        "--batch_size", "2",
        "--block_size", "256",
        "--learning_rate", "5e-5",
    ]

    # Run trainer (blocking) and capture output in train.log
    with open(LOG, "a", encoding="utf-8") as lf:
        try:
            code = subprocess.run(
                cmd, stdout=lf, stderr=subprocess.STDOUT, check=False
            ).returncode
        except OSError as exc:
            # The process could not even be started; surface that in the log
            # and fall through to the normal failure report.
            lf.write(f"Failed to launch trainer: {exc}\n")
            code = -1

    # Refresh model list and pre-select the fresh output only if it is present
    models = list_models()
    selected = str(OUT) if OUT.exists() and str(OUT) in models else None
    model_update = gr.update(choices=models, value=selected)

    if code == 0 and ZIP.exists():
        info = f"✅ Training complete. Saved: {OUT.name} | Zip: {ZIP.name}"
        return info, gr.update(value=str(ZIP), visible=True), ls_workspace(), read_logs(), model_update

    info = f"❌ Training failed (exit {code}). Check logs below."
    return info, gr.update(value=None, visible=False), ls_workspace(), read_logs(), model_update

def read_logs():
    """Return the tail of the training log for display.

    Returns:
        The last 20000 characters of ``train.log``, or ``"⏳ Waiting…"`` when
        the log file does not exist yet.
    """
    try:
        # The trainer subprocess writes raw bytes into this file, so decode
        # leniently: one undecodable byte must not break the whole UI callback.
        text = LOG.read_text(encoding="utf-8", errors="replace")
    except FileNotFoundError:
        return "⏳ Waiting…"
    return text[-20000:]

def refresh_download():
    """Re-sync the download widget, workspace listing and model dropdown.

    Returns:
        3-tuple: update for the download file component (pointing at the zip
        and visible only if the zip exists), the workspace listing, and an
        update carrying the current model choices.
    """
    available = list_models()
    zip_ready = ZIP.exists()
    zip_value = str(ZIP) if zip_ready else None

    download_update = gr.update(value=zip_value, visible=zip_ready)
    listing = ls_workspace()
    dropdown_update = gr.update(choices=available)
    return download_update, listing, dropdown_update

# ---------- test tab ----------
def import_zip(zfile):
    """Extract an uploaded model archive into ``ROOT / "imported_model"``.

    Any previous import is replaced, but only after the upload has been
    checked to be a zip archive.

    Args:
        zfile: Value delivered by the upload component. Either a filesystem
            path (``str`` / ``os.PathLike``) or an object exposing the path
            as ``.name``. Any falsy value means nothing was selected.

    Returns:
        Tuple ``(status message, dropdown update)``. The second item is a
        ``gr.update(choices=...)`` so the dropdown's choices are refreshed;
        returning a bare list would be treated as the dropdown's value.
    """
    if not zfile:
        return "❌ No zip selected.", gr.update(choices=list_models())

    if isinstance(zfile, (str, os.PathLike)):
        src = zfile
    else:
        src = getattr(zfile, "name", None)

    # Validate before touching the existing import so a bad upload does not
    # destroy a previously imported model.
    if not src or not zipfile.is_zipfile(src):
        return "❌ Not a valid .zip file.", gr.update(choices=list_models())

    dest = ROOT / "imported_model"
    if dest.exists():
        shutil.rmtree(dest, ignore_errors=True)
    dest.mkdir(parents=True, exist_ok=True)

    try:
        # NOTE(review): the archive is user-supplied. zipfile documents that it
        # attempts to neutralize absolute paths and ".." members on extraction;
        # confirm that is sufficient for this deployment.
        with zipfile.ZipFile(src, "r") as z:
            z.extractall(dest)
    except (zipfile.BadZipFile, OSError) as exc:
        return f"❌ Import failed: {exc}", gr.update(choices=list_models())

    return f"✅ Imported to {dest.name}", gr.update(choices=list_models())

def generate(model_path, prompt):
    """Generate text from a saved causal language model.

    Args:
        model_path: Path or identifier passed to ``from_pretrained``.
        prompt: Text to continue; must contain non-whitespace characters.

    Returns:
        The ``generated_text`` of the first pipeline result, or a message
        starting with ``❌`` when an input is missing or any step fails.
    """
    if not model_path:
        return "❌ Select a model."
    if not (prompt and prompt.strip()):
        return "❌ Enter a prompt."

    try:
        # Imported lazily so the UI can start without loading transformers.
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        # Reuse EOS as the padding token when the tokenizer defines none.
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token

        lm = AutoModelForCausalLM.from_pretrained(model_path)
        generator = pipeline("text-generation", model=lm, tokenizer=tokenizer)

        sampling = {
            "max_new_tokens": 220,
            "do_sample": True,
            "temperature": 0.2,
            "top_p": 0.9,
            "repetition_penalty": 1.2,
            "no_repeat_ngram_size": 4,
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
            "truncation": True,
        }
        results = generator(prompt, **sampling)
        return results[0]["generated_text"]
    except Exception as err:
        # UI boundary: show the failure in the output box rather than raising.
        return f"❌ Error: {err}"

# ---------- UI ----------
# Build the two-tab interface. Components are created in display order, so the
# statement order below defines the layout; event handlers are attached at the
# end once every component they reference exists.
with gr.Blocks(title="Python AI — Train & Test") as app:
    gr.Markdown("## 🧠 Python AI — Train & Test\nTrainer saves & zips. UI only shows existing artifacts.\n")

    # Test tab (declared first so we can update its dropdown from Train tab)
    with gr.Tab("Test"):
        gr.Markdown("### Choose a model folder or upload a .zip, then prompt it")
        refresh_btn = gr.Button("↻ Refresh Model List")
        # Choices are computed once at build time; the handlers below refresh them.
        model_list = gr.Dropdown(choices=list_models(), label="Available AIs", interactive=True)
        zip_in = gr.File(label="Or upload a model .zip", file_types=[".zip"])
        import_status = gr.Textbox(label="Import Status", interactive=False)
        prompt = gr.Textbox(label="Prompt", lines=8, placeholder="### Instruction:\nPython: write a function ...\n### Response:\n")
        go = gr.Button("Generate")
        out = gr.Textbox(label="AI Response", lines=20)

    # Train tab
    with gr.Tab("Train"):
        with gr.Row():
            ds = gr.File(label="📥 Upload JSONL", file_types=[".jsonl"])
            # Initial workspace listing, taken when the UI is built.
            ws = gr.Textbox(label="Workspace", lines=16, value=ls_workspace())
        up_status = gr.Textbox(label="Upload Status", interactive=False)
        start = gr.Button("🚀 Start Training", variant="primary")
        # Initial log tail, taken when the UI is built.
        logs = gr.Textbox(label="📜 Training Logs", lines=18, value=read_logs())
        status = gr.Textbox(label="Status", interactive=False)
        # Hidden until a zip exists; handlers toggle visibility afterwards.
        download_file = gr.File(label="📦 trained_model.zip", visible=ZIP.exists())
        refresh_dl_btn = gr.Button("Refresh Download")

    # Wiring
    # Each handler's return tuple must match its `outputs` list in order.
    ds.change(upload_dataset, inputs=ds, outputs=[up_status, ws])
    # start_training blocks until the trainer subprocess exits.
    start.click(
        start_training,
        outputs=[status, download_file, ws, logs, model_list]
    )
    refresh_dl_btn.click(
        refresh_download,
        outputs=[download_file, ws, model_list]
    )

    # Test-tab handlers: refresh choices, import an archive, run generation.
    refresh_btn.click(lambda: gr.update(choices=list_models()), outputs=model_list)
    zip_in.change(import_zip, inputs=zip_in, outputs=[import_status, model_list])
    go.click(generate, inputs=[model_list, prompt], outputs=out)

# Start the Gradio server; this runs at module execution (no __main__ guard).
app.launch()