|
|
import queue
import subprocess
import threading
import tkinter as tk
from tkinter import filedialog, messagebox, scrolledtext, ttk

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
|
|
|
|
|
class MultiModalGUI(tk.Tk):
    """Main application window: a ttk.Notebook holding the trainer's tabs.

    Each tab receives this window as its ``controller`` and can read the
    shared state defined here:

    * ``model`` -- starts as ``None``.
    * ``hparams`` -- dict with ``"lr"`` and ``"batch_size"`` entries.
    """

    def __init__(self):
        super().__init__()
        self.title("Multimodal Transformer Trainer")
        self.geometry("1200x800")

        # State shared with the tabs through their ``controller`` reference.
        self.model = None
        self.hparams = {"lr": 1e-4, "batch_size": 32}

        notebook = ttk.Notebook(self)
        self.notebook = notebook

        # Build every page first, then register them in display order.
        self.training_tab = TrainingTab(notebook, self)
        self.inference_tab = InferenceTab(notebook, self)
        self.data_tab = DataTab(notebook, self)
        self.settings_tab = SettingsTab(notebook, self)

        pages = (
            (self.training_tab, "Training"),
            (self.inference_tab, "Inference"),
            (self.data_tab, "Data"),
            (self.settings_tab, "Settings"),
        )
        for page, label in pages:
            notebook.add(page, text=label)

        notebook.pack(expand=True, fill="both")
|
|
|
|
|
class TrainingTab(ttk.Frame):
    """Notebook tab that runs a (simulated) training loop and plots its metrics.

    The loop runs on a daemon worker thread. Tk widgets may only be touched
    from the thread running the Tk main loop, so the worker never calls the
    widgets directly. It queues ``(callable, args)`` pairs, and
    ``_poll_events`` executes them on the main loop.
    """

    # How often (in ms) the main loop drains the worker's UI-update queue.
    POLL_MS = 100

    def __init__(self, parent, controller, epochs=10):
        """Build the metrics plot, the Start/Stop controls and the log pane.

        Args:
            parent: Container widget (the application's notebook).
            controller: The application object, kept as ``self.controller``.
            epochs: Number of epochs a single training run performs.
        """
        super().__init__(parent)
        self.controller = controller
        self.epochs = epochs

        # Worker-thread plumbing.
        self._stop_event = threading.Event()  # set to ask the worker to stop
        self._events = queue.Queue()  # (callable, args) to run on the main loop
        self._worker = None  # the current training thread, if any

        # A plain Figure rather than plt.figure(): pyplot would also register
        # the figure with its own global figure manager, which an embedded
        # canvas does not need.
        self.figure = Figure(figsize=(8, 4))
        self.ax = self.figure.add_subplot(111)
        self.canvas = FigureCanvasTkAgg(self.figure, self)
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Keep the buttons in their own row. Packing them side=LEFT directly
        # in the tab made the log pane occupy the space beside them.
        controls = ttk.Frame(self)
        controls.pack(fill=tk.X)
        self.start_btn = ttk.Button(controls, text="Start Training", command=self.start_training)
        self.start_btn.pack(side=tk.LEFT)
        self.stop_btn = ttk.Button(controls, text="Stop", command=self.stop_training)
        self.stop_btn.pack(side=tk.LEFT)

        self.logs = scrolledtext.ScrolledText(self, height=10)
        self.logs.pack(fill=tk.BOTH)

    def start_training(self):
        """Launch a training run on a daemon thread.

        If a run is already in progress, only a note is logged. The thread is
        a daemon so that closing the window does not wait for it.
        """
        if self._worker is not None and self._worker.is_alive():
            self.log("Training is already running.")
            return
        self._stop_event.clear()
        self._worker = threading.Thread(target=self._train, daemon=True)
        self._worker.start()
        self.after(self.POLL_MS, self._poll_events)

    def _post(self, func, *args):
        """Queue ``func(*args)`` to be run on the Tk main loop (any thread)."""
        self._events.put((func, args))

    def _train(self):
        """Worker body: simulate ``self.epochs`` epochs and queue UI updates."""
        losses, accs = [], []
        for epoch in range(self.epochs):
            if self._stop_event.is_set():
                self._post(self.log, f"Training stopped after {epoch} epoch(s).")
                return
            loss = np.random.rand()
            acc = np.random.rand()
            losses.append(loss)
            accs.append(acc)
            # Send snapshots: the worker keeps appending to its own lists.
            self._post(self.update_plot, list(losses), list(accs))
            self._post(self.log, f"Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.4f}")
        self._post(self.log, "Training finished.")

    def _poll_events(self):
        """Run queued worker updates on the main loop; reschedule while needed."""
        # Sample liveness *before* draining. If the worker had already exited,
        # everything it queued is in the queue now, so nothing is left behind.
        worker_alive = self._worker is not None and self._worker.is_alive()
        while True:
            try:
                func, args = self._events.get_nowait()
            except queue.Empty:
                break
            func(*args)
        if worker_alive:
            self.after(self.POLL_MS, self._poll_events)

    def update_plot(self, losses, accs):
        """Redraw the axes with the loss and accuracy series (main thread only)."""
        self.ax.clear()
        self.ax.plot(losses, label="Loss")
        self.ax.plot(accs, label="Accuracy")
        self.ax.legend()
        self.canvas.draw()

    def log(self, message):
        """Append ``message`` as a line to the log pane and scroll to it."""
        self.logs.insert(tk.END, message + "\n")
        self.logs.see(tk.END)

    def stop_training(self):
        """Ask a running training loop to stop before its next epoch."""
        if self._worker is None or not self._worker.is_alive():
            return
        self._stop_event.set()
        self.log("Stop requested.")
|
|
|
|
|
class InferenceTab(ttk.Frame):
    """Notebook tab that runs a placeholder inference on user-typed input."""

    # Tasks offered in the task selector.
    TASKS = ("text_generation", "image_captioning", "music_generation")
    # Number of input characters echoed back in the placeholder output.
    PREVIEW_CHARS = 20

    def __init__(self, parent, controller):
        """Build the task selector, input/output panes and the Run button."""
        super().__init__(parent)
        self.controller = controller

        # task_var is now actually bound to the combobox (it used to be created
        # and never attached). Preselecting a task means a run never reports
        # an empty task name. The list is fixed, so free text is disabled.
        self.task_var = tk.StringVar(value=self.TASKS[0])
        self.task_combo = ttk.Combobox(
            self,
            values=list(self.TASKS),
            textvariable=self.task_var,
            state="readonly",
        )
        self.task_combo.pack()

        self.input_text = scrolledtext.ScrolledText(self, height=5)
        self.input_text.pack()
        self.output_text = scrolledtext.ScrolledText(self, height=10)
        self.output_text.pack()
        self.run_btn = ttk.Button(self, text="Run Inference", command=self.run_inference)
        self.run_btn.pack()

    @staticmethod
    def format_output(task, input_data, preview_chars=PREVIEW_CHARS):
        """Return the placeholder output line for ``task`` applied to ``input_data``.

        Surrounding whitespace is stripped, including the trailing newline a
        Tk Text widget appends. The echoed input is cut to ``preview_chars``
        characters, and "..." is added only when text was actually dropped.
        """
        text = input_data.strip()
        preview = text[:preview_chars]
        if len(text) > preview_chars:
            preview += "..."
        return f"Model output for {task}: {preview}"

    def run_inference(self):
        """Read the selected task and the input, then append the output line."""
        task = self.task_var.get()
        # "1.0" to END includes the newline Tk keeps at the end of a Text widget.
        input_data = self.input_text.get("1.0", tk.END)
        if not input_data.strip():
            self.output_text.insert(tk.END, "Please enter some input first.\n")
        else:
            self.output_text.insert(tk.END, self.format_output(task, input_data) + "\n")
        self.output_text.see(tk.END)
|
|
|
|
|
class DataTab(ttk.Frame):
    """Notebook tab for choosing dataset files and listing what was chosen."""

    # File-dialog filters. "All files" keeps other formats selectable.
    FILETYPES = [
        ("Text", "*.txt"),
        ("CSV", "*.csv"),
        ("Images", ("*.jpg", "*.jpeg", "*.png")),
        ("All files", "*"),
    ]

    def __init__(self, parent, controller):
        """Build the Upload button and the preview pane."""
        super().__init__(parent)
        self.controller = controller
        # Every path selected so far, in selection order.
        self.files = []

        self.upload_btn = ttk.Button(self, text="Upload Dataset", command=self.upload_dataset)
        self.upload_btn.pack()

        self.preview_text = scrolledtext.ScrolledText(self, height=15)
        self.preview_text.pack()

    @staticmethod
    def format_upload(filenames):
        """Return one "Uploaded: <path>" line per path in ``filenames``."""
        return "".join(f"Uploaded: {name}\n" for name in filenames)

    def upload_dataset(self):
        """Let the user pick files, record them and list them in the preview."""
        selection = filedialog.askopenfilenames(filetypes=self.FILETYPES)
        # The result may be a tuple or (e.g. on cancel) an empty string,
        # depending on the Tk build. splitlist normalizes both to a tuple.
        filenames = self.tk.splitlist(selection)
        if not filenames:
            return  # dialog cancelled: previously this logged "Uploaded: ()"
        self.files.extend(filenames)
        self.preview_text.insert(tk.END, self.format_upload(filenames))
        self.preview_text.see(tk.END)
|
|
|
|
|
class SettingsTab(ttk.Frame):
    """Notebook tab for editing hyperparameters and saving/loading model weights."""

    def __init__(self, parent, controller):
        """Build the hyperparameter form and the Apply/Save/Load buttons."""
        super().__init__(parent)
        self.controller = controller

        # Seed the form from the controller's hyperparameters rather than from
        # hard-coded strings that can disagree with the values in use.
        hparams = getattr(controller, "hparams", None) or {}

        ttk.Label(self, text="Learning Rate:").pack()
        self.lr_entry = ttk.Entry(self)
        self.lr_entry.insert(0, str(hparams.get("lr", 0.001)))
        self.lr_entry.pack()

        ttk.Label(self, text="Batch Size:").pack()
        self.batch_entry = ttk.Entry(self)
        self.batch_entry.insert(0, str(hparams.get("batch_size", 32)))
        self.batch_entry.pack()

        # The entries used to be display-only; Apply validates them and
        # writes them into controller.hparams.
        self.apply_btn = ttk.Button(self, text="Apply", command=self.apply_settings)
        self.apply_btn.pack()

        self.save_btn = ttk.Button(self, text="Save Model", command=self.save_model)
        self.save_btn.pack()
        self.load_btn = ttk.Button(self, text="Load Model", command=self.load_model)
        self.load_btn.pack()

    @staticmethod
    def parse_hparams(lr_text, batch_text):
        """Parse the form's text into ``{"lr": float, "batch_size": int}``.

        Raises:
            ValueError: if the learning rate is not a positive finite number,
                or the batch size is not a positive whole number.
        """
        try:
            lr = float(lr_text)
        except ValueError:
            raise ValueError(f"Learning rate must be a number, got {lr_text!r}.") from None
        # Chained comparison also rejects NaN, for which every comparison is False.
        if not 0 < lr < float("inf"):
            raise ValueError(f"Learning rate must be a positive finite number, got {lr_text!r}.")
        try:
            batch_size = int(batch_text)
        except ValueError:
            raise ValueError(f"Batch size must be a whole number, got {batch_text!r}.") from None
        if batch_size <= 0:
            raise ValueError(f"Batch size must be positive, got {batch_text!r}.")
        return {"lr": lr, "batch_size": batch_size}

    def apply_settings(self):
        """Validate the form and store it in ``controller.hparams``.

        Returns:
            True if the values were applied, False if validation failed
            (the user is shown an error dialog).
        """
        try:
            parsed = self.parse_hparams(self.lr_entry.get(), self.batch_entry.get())
        except ValueError as err:
            messagebox.showerror("Invalid settings", str(err), parent=self)
            return False
        current = getattr(self.controller, "hparams", None)
        if current is None:
            self.controller.hparams = parsed
        else:
            current.update(parsed)
        return True

    def save_model(self):
        """Ask for an .h5 path and call ``save_weights`` on ``controller.model``."""
        model = self.controller.model
        if model is None:
            # Calling save_weights on None used to raise AttributeError.
            messagebox.showerror("Save Model", "There is no model to save yet.", parent=self)
            return
        filename = filedialog.asksaveasfilename(
            defaultextension=".h5", filetypes=[("HDF5", "*.h5")]
        )
        if not filename:
            return  # dialog cancelled
        try:
            model.save_weights(filename)
        except Exception as err:  # UI boundary: report any backend failure to the user
            messagebox.showerror("Save Model", f"Could not save weights:\n{err}", parent=self)

    def load_model(self):
        """Ask for an .h5 file and call ``load_weights`` on ``controller.model``."""
        model = self.controller.model
        if model is None:
            # Calling load_weights on None used to raise AttributeError.
            messagebox.showerror(
                "Load Model", "There is no model to load weights into yet.", parent=self
            )
            return
        filename = filedialog.askopenfilename(filetypes=[("HDF5", "*.h5")])
        if not filename:
            return  # dialog cancelled
        try:
            model.load_weights(filename)
        except Exception as err:  # UI boundary: report any backend failure to the user
            messagebox.showerror("Load Model", f"Could not load weights:\n{err}", parent=self)
|
|
|
|
|
if __name__ == "__main__":
    # Create the main window and block in Tk's event loop until it is closed.
    MultiModalGUI().mainloop()