# gf-ui / app.py
# Last commit: d429163 "Update app.py" by rsalehin (verified)
import html
import itertools
import time

import gradio as gr
import numpy as np
import pandas as pd
# Custom CSS inspired by Google's Material Design: Roboto font, #1a73e8 blue
# buttons and slider track, light-grey (#dadce0) input borders. This string is
# handed to gr.Blocks(css=...) when the tabbed UI is assembled.
# NOTE(review): the .gr-* selectors (.gr-button, .gr-textbox, .gr-file,
# .gr-slider, ...) look like class names from older Gradio releases; confirm
# they still exist in the installed Gradio version, otherwise those rules
# silently have no effect.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500&display=swap');
body, .gradio-container {
font-family: 'Roboto', sans-serif;
background-color: #FAFAFA;
color: #202124;
}
h1, h2, h3, h4, .title {
font-weight: 400;
}
.gr-button {
background-color: #1a73e8 !important;
color: #fff !important;
border-radius: 4px !important;
font-weight: 500;
padding: 0.6em 1.2em;
transition: background-color 0.3s ease;
}
.gr-button:hover {
background-color: #1669c1 !important;
}
.gr-input, .gr-textbox, .gr-file, .gr-slider input {
border: 1px solid #dadce0 !important;
border-radius: 4px !important;
padding: 0.5em;
}
.gr-slider > div {
background-color: #1a73e8 !important;
}
.tab-item {
padding: 1em;
}
"""
# ---------------------------------
# Data Upload and Preprocessing Module
# ---------------------------------
def _preview_csv(path, augment):
    """Return an HTML preview of a CSV file, optionally with a noise-augmented copy."""
    try:
        df = pd.read_csv(path)
    except Exception as e:  # report any read/parse failure back to the UI
        return f"Error reading CSV: {html.escape(str(e))}"
    original_preview = df.head().to_html(classes="dataframe", border=0)
    result = f"<b>Original Data (Preview):</b><br>{original_preview}"
    if not augment:
        return result
    num_cols = df.select_dtypes(include=[np.number]).columns
    if len(num_cols) == 0:
        return result + "<br><br><b>Note:</b> No numeric columns found for augmentation."
    # Simple augmentation: add Gaussian noise (mean 0, std 0.05) to every
    # numeric column of a copy, leaving the original frame untouched.
    df_aug = df.copy()
    noise = np.random.normal(0, 0.05, df_aug[num_cols].shape)
    df_aug[num_cols] = df_aug[num_cols] + noise
    aug_preview = df_aug.head().to_html(classes="dataframe", border=0)
    return result + f"<br><br><b>Augmented Data (Preview):</b><br>{aug_preview}"


def _preview_jsonl(path):
    """Return the first five lines of a JSONL file as escaped, preformatted HTML."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            # islice stops after five lines instead of reading the whole file.
            preview = "".join(itertools.islice(f, 5))
    except Exception as e:  # e.g. missing file or undecodable bytes
        return f"Error reading JSONL file: {html.escape(str(e))}"
    return (
        f"<b>File:</b> {html.escape(path)}<br><br>"
        f"<b>Preview:</b><br><pre>{html.escape(preview)}</pre>"
    )


def _preview_txt(path):
    """Return the first 500 characters of a text file as escaped, preformatted HTML."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read(500)
    except Exception as e:  # e.g. missing file or undecodable bytes
        return f"Error reading TXT file: {html.escape(str(e))}"
    return (
        f"<b>File:</b> {html.escape(path)}<br><br>"
        f"<b>Preview (first 500 characters):</b><br><pre>{html.escape(content)}</pre>"
    )


def process_data(file, augment):
    """Validate an uploaded dataset file and return an HTML preview.

    CSV files are parsed with pandas and their first five rows are shown; if
    ``augment`` is true, a copy with Gaussian noise (std 0.05) added to every
    numeric column is previewed as well. JSONL files show their first five
    lines and TXT files their first 500 characters; that raw text is
    HTML-escaped inside ``<pre>`` so it is displayed verbatim rather than
    interpreted as markup.

    Args:
        file: The value from ``gr.File``. Depending on the Gradio version this
            is a filepath string or a tempfile-like object exposing ``.name``;
            both are accepted. ``None`` when nothing has been uploaded.
        augment: Whether to apply numeric-noise augmentation to CSV data.

    Returns:
        An HTML fragment with the preview, or a plain/HTML message describing
        why the file could not be previewed.
    """
    if file is None:
        return "No file uploaded yet."
    path = file if isinstance(file, str) else file.name
    ext = path.split(".")[-1].lower()
    if ext == "csv":
        return _preview_csv(path, augment)
    if ext == "jsonl":
        return _preview_jsonl(path)
    if ext == "txt":
        return _preview_txt(path)
    return "Unsupported file type. Please upload a CSV, JSONL, or TXT file."
# Tab interface for uploading and previewing a dataset via ``process_data``.
data_upload_interface = gr.Interface(
    fn=process_data,
    inputs=[
        # Restrict the file picker to the extensions process_data understands,
        # so unsupported files are filtered out before upload.
        gr.File(label="Upload CSV/JSONL/TXT File", file_types=[".csv", ".jsonl", ".txt"]),
        gr.Checkbox(label="Apply Data Augmentation", value=False),
    ],
    outputs=gr.HTML(),
    title="Data Upload & Preprocessing",
    description="Upload your dataset file, validate its format, and optionally apply data augmentation.",
)
# ---------------------------------
# Hyperparameter Configuration Module
# ---------------------------------
def configure_hyperparameters(learning_rate, batch_size, epochs):
    """Render the chosen hyperparameters as an HTML summary.

    Each value is shown as ``<b>Label:</b> value``; the three entries are
    joined with ``<br>`` tags in the order learning rate, batch size, epochs.
    """
    entries = (
        ("Learning Rate", learning_rate),
        ("Batch Size", batch_size),
        ("Epochs", epochs),
    )
    return "<br>".join(f"<b>{label}:</b> {value}" for label, value in entries)
# Tab interface for choosing training hyperparameters; the selections are
# echoed back as HTML by ``configure_hyperparameters``.
hyperparameter_interface = gr.Interface(
    fn=configure_hyperparameters,
    inputs=[
        gr.Slider(0.0001, 0.1, value=0.001, label="Learning Rate", step=0.0001),
        gr.Dropdown(choices=["16", "32", "64", "128"], value="32", label="Batch Size"),
        # precision=0 rounds the entry and passes an int, so epochs are whole
        # numbers (and display as "10", not "10.0").
        gr.Number(value=10, label="Epochs", precision=0),
    ],
    outputs=gr.HTML(),
    title="Hyperparameter Settings",
    description="Adjust the training parameters for fine-tuning the model.",
)
# ---------------------------------
# Training Dashboard Module (Simulation)
# ---------------------------------
def simulate_training(steps=100, delay=0.03):
    """Simulate a training run and return data for the dashboard plots.

    Args:
        steps: Number of simulated iterations (default 100).
        delay: Seconds slept per iteration to mimic training time.

    Returns:
        A tuple ``(progress_df, loss_df, sample_output)``: ``progress_df`` has
        columns ``step`` and ``progress`` (percent complete), ``loss_df`` has
        columns ``step`` and ``loss`` (a random value in [0, 1) plus a term
        that falls linearly to 0), and ``sample_output`` is a placeholder
        generation string.
    """
    step_vals = []
    progress_vals = []
    loss_vals = []
    for i in range(1, steps + 1):
        time.sleep(delay)  # simulate per-iteration training delay
        step_vals.append(i)
        progress_vals.append(100 * i / steps)
        loss_vals.append(np.random.rand() + (steps - i) / steps)  # simulated loss curve
    progress_df = pd.DataFrame({"step": step_vals, "progress": progress_vals})
    loss_df = pd.DataFrame({"step": step_vals, "loss": loss_vals})
    sample_output = "This is a generated snippet from the fine-tuned Gemma model."
    return progress_df, loss_df, sample_output


# gr.LinePlot renders a pandas DataFrame directly (gr.Plot expects a figure
# object such as matplotlib/plotly/altair, which this simulation does not build).
training_interface = gr.Interface(
    fn=simulate_training,
    inputs=[],
    outputs=[
        gr.LinePlot(x="step", y="progress", label="Training Progress (%)"),
        gr.LinePlot(x="step", y="loss", label="Loss Curve"),
        gr.Textbox(label="Sample Output", lines=3),
    ],
    title="Training Dashboard",
    description="Monitor training progress in real-time (simulation).",
)
# ---------------------------------
# Model Export Module
# ---------------------------------
def export_model(export_format, delay=2.0):
    """Simulate exporting the model and return an HTML status message.

    Args:
        export_format: The format label picked in the radio group, or a falsy
            value (``None``/``""``) when nothing was selected.
        delay: Seconds to sleep to mimic the export work (default 2).

    Returns:
        An HTML snippet naming the (escaped) export format with a placeholder
        download link, or a prompt to choose a format when none was given.
    """
    if not export_format:
        # Without this guard the message would read "exported as None".
        return "Please select an export format."
    time.sleep(delay)  # simulate the export process
    return (
        f"Model exported as <b>{html.escape(export_format)}</b>! "
        "Download link: <a href='#'>[dummy_link]</a>"
    )
# Tab interface for the (simulated) export step: a single radio choice of
# target format is passed to ``export_model``, whose HTML reply is shown.
export_interface = gr.Interface(
    fn=export_model,
    inputs=gr.Radio(
        choices=["TensorFlow SavedModel", "PyTorch", "GGUF"],
        label="Select Export Format",
    ),
    outputs=gr.HTML(),
    title="Model Export",
    description="Export your fine-tuned model in the desired format.",
)
# ---------------------------------
# Help & Documentation Module
# ---------------------------------
# Markdown rendered on the Help tab: one numbered entry per tab of the UI,
# in the same order as the tabs.
# NOTE(review): the "official documentation" link is the placeholder
# https://example.com; replace it with the real documentation URL.
help_text = """
### Getting Started with Gemma Fine-tuning UI
1. **Data Upload & Preprocessing:**
Upload your dataset in CSV, JSONL, or TXT format. The app validates your file and shows a preview.
Optionally, enable data augmentation (e.g., adding random noise to numeric columns).
2. **Hyperparameter Settings:**
Configure training parameters such as learning rate, batch size, and epochs.
3. **Training Dashboard:**
Monitor the training progress in real-time. This demo simulates a training session.
4. **Model Export:**
Export your fine-tuned model in a variety of formats.
For more detailed documentation, please refer to the [official documentation](https://example.com).
"""
# Help tab. The Markdown output is pre-populated with ``help_text`` so the
# content is visible as soon as the tab is opened, instead of only after the
# input-less Interface's submit/generate button is pressed.
help_interface = gr.Interface(
    fn=lambda: help_text,
    inputs=[],
    outputs=gr.Markdown(help_text),
    title="Help & Documentation",
)
# ---------------------------------
# Assemble the Tabbed Interface with Gradio Blocks
# ---------------------------------
# Assemble every module interface into one tabbed page styled by custom_css.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Gemma Fine-tuning UI</h1>")
    with gr.Tabs():
        for tab_label, tab_interface in (
            ("Data Upload & Preprocessing", data_upload_interface),
            ("Hyperparameters", hyperparameter_interface),
            ("Training", training_interface),
            ("Model Export", export_interface),
            ("Help", help_interface),
        ):
            with gr.TabItem(tab_label):
                tab_interface.render()

# Launch only when run as a script, so importing this module builds ``demo``
# without starting a server.
if __name__ == "__main__":
    demo.launch(share=True)