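"""
Hausa Health Assistant: a Gradio app for fine-tuning distilgpt2 on a Hausa medical-conversation
dataset with LoRA adapters, and for chatting with the resulting model.

Model loading, dataset preparation, and training each run in background threads so the UI stays
responsive; a thread-safe StatusManager shares progress with the interface, which polls it on a timer.
"""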
import os
import argparse
import torch
import gradio as gr
import threading
import time
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType, PeftModel  # PeftModel is used to check whether LoRA adapters are already applied
import random
# Set random seeds for reproducibility
random.seed(42)
torch.manual_seed(42)
# Determine the device for model execution. Prioritize CUDA (GPU) if available, otherwise use CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Global variables to store the loaded model, tokenizer, and training-related data
model = None
tokenizer = None
training_stats = {}
trainer = None
train_dataset = None
eval_dataset = None
# --- Status Management (Thread-Safe) ---
# A class to manage and update the application's status, progress, and error messages
# in a thread-safe manner, as background operations will modify these.
class StatusManager:
def __init__(self):
self._lock = threading.Lock() # Ensures only one thread can modify status at a time
self.status = "Ready" # Current descriptive status message
self.progress = 0 # Progress percentage (0-100)
self.model_loaded = False # Boolean flag: Is the base model loaded?
self.model_trained = False # Boolean flag: Has the model completed training?
self.error = None # Stores any error message encountered
def update_status(self, status: str, progress: int = None, error: str = None):
"""Updates the current status, optional progress percentage, and optional error message."""
with self._lock: # Acquire lock before modifying shared state
self.status = status
if progress is not None:
self.progress = progress
if error is not None:
self.error = error
def set_model_loaded(self, loaded: bool):
"""Sets the flag indicating whether the model has been loaded."""
with self._lock:
self.model_loaded = loaded
def set_model_trained(self, trained: bool):
"""Sets the flag indicating whether the model has completed training."""
with self._lock:
self.model_trained = trained
def get_status(self):
"""Returns a dictionary containing the current status, progress, and flags."""
with self._lock: # Acquire lock before reading shared state
return {
'status': self.status,
'progress': self.progress,
'model_loaded': self.model_loaded,
'model_trained': self.model_trained,
'error': self.error
}
status_manager = StatusManager()
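# A single shared StatusManager instance: the background threads (model loading, dataset loading,
# training) write status updates into it, and the Gradio Timer defined near the bottom of the UI
# polls get_status() every 2 seconds to refresh the status display.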
# --- Model Loading ---
def initialize_model_background():
"""
Loads the base pre-trained language model (distilgpt2) and its tokenizer
in a background thread to keep the Gradio UI responsive.
"""
global model, tokenizer
try:
status_manager.update_status("πŸ”„ Loading base distilgpt2 model...", 10)
# Clear CUDA cache if a GPU is available to free up memory before loading a new model
if torch.cuda.is_available():
torch.cuda.empty_cache()
status_manager.update_status("πŸ”„ Downloading model weights (this might take a while)...", 30)
        # distilgpt2 is used as the base model: it is small enough to fine-tune and run on CPU
        model_name = "distilgpt2"
# Load the tokenizer associated with the model
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True, # Allow custom code in model config if necessary
)
# Ensure the tokenizer has a padding token, which is crucial for batch processing during training
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token # Use EOS token as pad token if not defined
status_manager.update_status("πŸ”„ Loading model into memory...", 50)
# Load the causal language model
model = AutoModelForCausalLM.from_pretrained(
model_name,
# Use float16 for GPU (half precision for faster computation, lower memory), float32 for CPU
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
# 'device_map="auto"' intelligently distributes model layers across available GPUs.
# For CPU, it should be None, and then explicitly moved.
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True,
)
status_manager.update_status("πŸ”„ Moving model to target device (GPU/CPU)...", 90)
# If not using 'device_map="auto"' (i.e., on CPU), explicitly move the model to the target device
if not torch.cuda.is_available():
model = model.to(device)
status_manager.update_status("βœ… Model loaded successfully!", 100)
status_manager.set_model_loaded(True)
# Report the number of parameters to give an idea of model size
param_count = sum(p.numel() for p in model.parameters())
print(f"βœ… Model initialized! Parameters: {param_count/1e6:.2f}M") # Display in millions
except Exception as e:
error_msg = f"❌ Error loading model: {str(e)}"
status_manager.update_status("❌ Model loading failed", 0, error_msg)
print(error_msg)
def start_model_loading():
"""Initiates the model loading process in a background thread."""
if status_manager.get_status()['model_loaded']:
return "Model already loaded!" # Prevent loading multiple times
thread = threading.Thread(target=initialize_model_background, daemon=True)
thread.start()
return "πŸš€ Started loading model in background..."
# --- Model Preparation for Training (LoRA) ---
def prepare_model_for_training():
"""
Applies LoRA (Low-Rank Adaptation) adapters to the base model.
This makes the model more memory-efficient and faster to fine-tune.
"""
global model
state = status_manager.get_status()
if not state['model_loaded']:
return "❌ Please load the model first!"
if model is None:
return "❌ Model not available!"
try:
status_manager.update_status("πŸ”„ Configuring LoRA adapters...", 0)
# Check if LoRA adapters are already applied.
if isinstance(model, PeftModel):
status_manager.update_status("βœ… Model already prepared for training", 100)
return "βœ… Model already prepared for training"
# Define LoRA configuration. Target modules are specific to distilgpt2's architecture.
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8, # LoRA attention dimension (e.g., 8, 16, 32)
lora_alpha=16, # Alpha parameter for LoRA scaling
lora_dropout=0.1, # Dropout probability for LoRA layers
bias="none", # Bias type (none, all, lora_only)
# Adjusted target modules for distilgpt2
target_modules=["c_attn", "c_proj", "c_fc"],
)
# Apply LoRA to the base model, making only a small portion trainable
model = get_peft_model(model, lora_config)
model.print_trainable_parameters() # Prints a summary of trainable vs. total parameters
status_manager.update_status("βœ… LoRA configuration applied!", 100)
return "βœ… LoRA configuration applied! Model ready for training."
except Exception as e:
error_msg = f"❌ Error preparing model for LoRA: {str(e)}"
status_manager.update_status("❌ Model preparation failed", 0, error_msg)
return error_msg
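# Note: c_attn, c_proj, and c_fc are GPT-2-style Conv1D modules. Recent PEFT versions detect
# Conv1D targets and adjust the fan_in_fan_out setting themselves; passing fan_in_fan_out=True
# to LoraConfig explicitly is an option if PEFT emits a warning about it.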
# --- Dataset Loading and Preprocessing ---
def format_chat_template(example):
"""
Formats a single conversation example from the dataset into a plain text string
that the language model can understand for training.
"""
if 'conversations' not in example or not example['conversations']:
return {"text": ""}
formatted_text = ""
    # Optional: add a system message at the beginning of each conversation, e.g. (in Hausa):
    # formatted_text += "System: Kai likita ne mai hankali da Ζ™warewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya.\n"
    # (roughly: "You are a careful doctor with expertise in health care. Give science-based advice suited to Nigerian culture.")
for turn in example['conversations']:
role = turn.get('from', turn.get('role', '')) # Handles different key names for roles
content = turn.get('value', turn.get('content', '')) # Handles different key names for content
if role.lower() == 'human' or role.lower() == 'user':
formatted_text += f"User: {content}\n"
elif role.lower() == 'gpt' or role.lower() == 'assistant':
formatted_text += f"Assistant: {content}\n"
# Append the EOS token to mark the end of a conversation for the model
formatted_text += tokenizer.eos_token
return {"text": formatted_text}
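# For a single user/assistant exchange, format_chat_template produces training text of the form:
#   "User: <question>\nAssistant: <answer>\n<eos_token>"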
def load_dataset_background():
"""
Loads the Hausa medical conversations dataset from Hugging Face Hub and
preprocesses it for training, all in a background thread.
"""
global train_dataset, eval_dataset
try:
status_manager.update_status("πŸ”„ Loading Hausa medical dataset from Hugging Face Hub...", 10)
dataset_name = "ictbiortc/hausa-medical-conversations-format-9k"
dataset = load_dataset(dataset_name)
if dataset is None:
raise ValueError("Dataset not found or could not be loaded.")
status_manager.update_status("πŸ”„ Processing dataset (formatting conversations)...", 40)
# If the dataset doesn't explicitly have a 'test' split, create one
        if 'test' not in dataset:
            print("No 'test' split found in dataset, creating a 10% test split from 'train'.")
            dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
        train_dataset_raw = dataset['train']
        eval_dataset_raw = dataset['test']
# Apply the chat formatting function to both train and evaluation splits
train_dataset = train_dataset_raw.map(
format_chat_template,
remove_columns=train_dataset_raw.column_names, # Remove original columns to keep only 'text'
desc="Formatting train dataset"
)
eval_dataset = eval_dataset_raw.map(
format_chat_template,
remove_columns=eval_dataset_raw.column_names,
desc="Formatting eval dataset"
)
# Filter out any examples that resulted in empty text after formatting
train_dataset = train_dataset.filter(lambda x: len(x['text'].strip()) > 0, desc="Filtering empty train examples")
eval_dataset = eval_dataset.filter(lambda x: len(x['text'].strip()) > 0, desc="Filtering empty eval examples")
status_manager.update_status(f"βœ… Dataset loaded! Train samples: {len(train_dataset)}, Validation samples: {len(eval_dataset)}", 100)
print(f"Dataset loading complete: Train samples={len(train_dataset)}, Eval samples={len(eval_dataset)}")
except Exception as e:
error_msg = f"❌ Error loading or processing dataset: {str(e)}"
status_manager.update_status("❌ Dataset loading failed", 0, error_msg)
print(error_msg)
train_dataset = None
eval_dataset = None
# --- Custom Data Collator ---
# The data collator turns a list of raw text examples into a model-ready batch: it tokenizes the
# texts, truncates/pads every sequence to a fixed max_length, and builds the labels tensor that
# the Trainer expects for causal language modeling.
class DataCollator:
def __init__(self, tokenizer, max_length=512):
self.tokenizer = tokenizer
self.max_length = max_length
def __call__(self, examples):
texts = [example['text'] for example in examples]
# Tokenize the batch of texts
tokenized = self.tokenizer(
texts,
truncation=True, # Truncate sequences longer than max_length
padding="max_length", # Pad to max_length for consistency within the batch
max_length=self.max_length,
return_tensors="pt" # Return PyTorch tensors
)
        # For causal language modeling, the labels are the input_ids themselves
        # (the model learns to predict the next token at each position).
        labels = tokenized['input_ids'].clone()
        # Mask padding positions with -100 so they are ignored by the loss; this matters here
        # because the pad token is the EOS token, and only the real EOS should be learned.
        labels[tokenized['attention_mask'] == 0] = -100
        tokenized['labels'] = labels
        return tokenized
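# Note: transformers' built-in DataCollatorForLanguageModeling(tokenizer, mlm=False) offers similar
# label handling for pre-tokenized datasets; the custom collator above is kept because it tokenizes
# the raw 'text' field on the fly.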
# --- Custom Training Progress Callback ---
class CustomProgressCallback(TrainerCallback):
"""
A custom callback for the Hugging Face Trainer. It updates the Gradio UI's
status display with real-time training progress, including steps, percentage,
and current loss.
"""
def __init__(self, status_manager_instance):
self.status_manager = status_manager_instance
self.last_logged_progress = -1
self.last_logged_loss = None
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
"""Called whenever the Trainer logs something (e.g., loss, learning rate)."""
if logs is None:
return
current_loss = logs.get('loss')
if current_loss is not None:
self.last_logged_loss = current_loss
# Update progress based on global step, but be mindful of initial 0%
if state.max_steps > 0:
progress = int((state.global_step / state.max_steps) * 100)
# Only update if progress has increased or if it's the very first log
if progress != self.last_logged_progress or state.global_step == 1:
loss_info = f", Loss: {self.last_logged_loss:.4f}" if self.last_logged_loss is not None else ""
self.status_manager.update_status(
f"πŸš€ Training... Step {state.global_step}/{state.max_steps}{loss_info}",
progress
)
self.last_logged_progress = progress
else:
# Fallback if max_steps isn't set yet or is 0
self.status_manager.update_status(f"πŸš€ Training... Step {state.global_step}", None)
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the end of each training step."""
# Ensure that `on_log` captures the loss, and `on_step_end` provides frequent updates.
# This will mainly ensure the progress bar updates even if loss isn't logged every single step.
if state.max_steps > 0:
progress = int((state.global_step / state.max_steps) * 100)
if progress > self.last_logged_progress or (state.global_step == 1 and self.last_logged_progress == -1):
loss_info = f", Loss: {self.last_logged_loss:.4f}" if self.last_logged_loss is not None else ""
self.status_manager.update_status(
f"πŸš€ Training... Step {state.global_step}/{state.max_steps}{loss_info}",
progress
)
self.last_logged_progress = progress
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the end of each training epoch."""
self.status_manager.update_status(f"βœ… Epoch {int(state.epoch)} completed.",
int((state.global_step / state.max_steps) * 100) if state.max_steps > 0 else None)
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Called at the very end of the training process."""
# The final status is set by train_model_background, this is just for intermediate clarity.
pass
# --- Model Training Function ---
def train_model_background(batch_size, grad_accum, epochs, lr):
"""
Manages the entire model training lifecycle in a background thread:
loading dataset (if needed), setting up the Trainer, and initiating training.
"""
global model, tokenizer, trainer, training_stats, train_dataset, eval_dataset
    # PyTorch anomaly detection is left disabled for speed. Uncomment the line below to get
    # detailed autograd tracebacks if in-place modification errors appear during training.
    # torch.autograd.set_detect_anomaly(True)
    print("PyTorch anomaly detection is DISABLED for faster training.")
try:
# Step 1: Ensure dataset is loaded and ready
if train_dataset is None or eval_dataset is None:
status_manager.update_status("πŸ”„ Loading dataset for training...", 5)
load_dataset_background() # Call the background dataset loader
if train_dataset is None or len(train_dataset) == 0:
error_msg = "❌ Training dataset is empty or failed to load - training cannot proceed."
status_manager.update_status("❌ Training failed", 0, error_msg)
return
if eval_dataset is None or len(eval_dataset) == 0:
print("Warning: Evaluation dataset is empty or failed to load. Evaluation during training will be skipped.")
status_manager.update_status("πŸ”„ Setting up Hugging Face Trainer...", 10)
print(f"Trainer setup: Batch size={batch_size}, Grad Accum={grad_accum}, Epochs={epochs}, LR={lr}")
# Step 2: Initialize Data Collator
data_collator = DataCollator(tokenizer)
print("Data collator initialized.")
# Step 3: Configure Training Arguments
training_args = TrainingArguments(
output_dir="./results", # Directory for saving checkpoints
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size, # Use same batch size for evaluation
gradient_accumulation_steps=grad_accum,
num_train_epochs=epochs,
learning_rate=lr,
warmup_steps=10, # Number of steps for linear learning rate warmup
        logging_steps=1, # Log training progress every step (important for granular feedback)
        save_steps=100, # Save a checkpoint every 100 steps
        eval_steps=100, # Evaluation interval (only takes effect if an evaluation strategy is enabled)
        save_total_limit=2, # Keep only the last 2 checkpoints to save disk space
        remove_unused_columns=False, # Necessary when dataset columns don't directly match model inputs
        dataloader_drop_last=True, # Drop the last incomplete batch for consistent batch sizes
        report_to="none", # Disable reporting to external services like Weights & Biases
optim="adamw_torch", # AdamW optimizer (PyTorch implementation)
lr_scheduler_type="linear", # Linear learning rate decay
seed=42, # Random seed for reproducibility
        # Pin memory for faster host-to-GPU data transfer when a GPU is present
        dataloader_pin_memory=torch.cuda.is_available(),
# For CPU, smaller log/eval steps are useful for frequent feedback without much overhead
)
# Step 4: Initialize the Hugging Face Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
# Only pass eval_dataset if it's not empty, otherwise Trainer might error
eval_dataset=eval_dataset if len(eval_dataset) > 0 else None,
data_collator=data_collator,
tokenizer=tokenizer, # Providing tokenizer to Trainer for internal operations (e.g., logging)
)
# Add the custom progress callback to the trainer for UI updates
trainer.add_callback(CustomProgressCallback(status_manager))
status_manager.update_status("πŸš€ Starting training process...", 30)
print("Training initiated...")
# Step 5: Start the training loop
train_result = trainer.train()
# Step 6: Store final training results
training_stats = train_result
status_manager.set_model_trained(True)
# Calculate and display total training time
training_time = train_result.metrics.get('train_runtime', 0)
minutes = int(training_time // 60)
seconds = int(training_time % 60)
final_loss = train_result.training_loss if hasattr(train_result, 'training_loss') else train_result.metrics.get('train_loss', 'N/A')
success_msg = f"πŸŽ‰ Training completed! Final Loss: {final_loss:.4f}, Time: {minutes}m {seconds}s"
status_manager.update_status(success_msg, 100)
print(success_msg)
except Exception as e:
error_msg = f"❌ Training failed: {str(e)}"
status_manager.update_status("❌ Training failed", 0, error_msg)
print(error_msg)
def start_training(batch_size, grad_accum, epochs, lr):
"""Initiates the model training process in a dedicated background thread."""
state = status_manager.get_status()
if not state['model_loaded']:
return "❌ Please load the model first!"
if model is None:
return "❌ Model not available!"
# Start the training thread with the provided parameters
thread = threading.Thread(
target=train_model_background,
args=(batch_size, grad_accum, epochs, lr),
daemon=True # Daemon threads exit automatically when the main program exits
)
thread.start()
return "πŸš€ Started training in background..."
# --- Status Retrieval for UI ---
def get_current_status():
"""Retrieves the current application status from the StatusManager and formats it for Gradio display."""
state = status_manager.get_status()
status_text = f"""πŸ“Š **Current Status**: {state['status']}
πŸ“ˆ **Progress**: {state['progress']}%
πŸ€– **Model Loaded**: {'βœ…' if state['model_loaded'] else '❌'}
πŸŽ“ **Model Trained**: {'βœ…' if state['model_trained'] else '❌'}"""
if state['error']:
status_text += f"\n❌ **Error**: {state['error']}"
return status_text
# --- Chat Functionality (Inference) ---
def chat_with_model(message, history, temperature=1.0, max_tokens=200):
"""
Generates a conversational response from the loaded model based on the user's message
and the ongoing chat history.
"""
global model, tokenizer
state = status_manager.get_status()
if not state['model_loaded'] or model is None:
# If model is not loaded, return an error message to the user
return history + [{"role": "user", "content": message},
{"role": "assistant", "content": "❌ Please load the model first!"}]
# Build the conversation history into a single string, respecting roles
conversation = ""
    # Optional: prepend a system message here for inference if training didn't include one.
    # The same Hausa system prompt shown in format_chat_template could be reused:
    # system_message = "Kai likita ne mai hankali da Ζ™warewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya.\n"
    # conversation += system_message
for msg_pair in history: # history is a list of {"role": ..., "content": ...} dictionaries
if msg_pair["role"] == "user":
conversation += f"User: {msg_pair['content']}\n"
else: # "assistant"
conversation += f"Assistant: {msg_pair['content']}\n"
# Append the current user's message and prompt the assistant for a response
conversation += f"User: {message}\nAssistant:"
try:
# Tokenize the entire conversation string for model input
inputs = tokenizer(conversation, return_tensors="pt").to(device)
with torch.no_grad(): # Disable gradient calculations during inference for speed and memory efficiency
# Generate the model's response
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens, # Maximum number of tokens to generate for the response
temperature=temperature, # Controls randomness: higher = more creative, lower = more deterministic
top_p=0.95, # Nucleus sampling: sample from top P probability mass
top_k=50, # Top-k sampling: sample from top K most probable tokens
do_sample=True, # Enable sampling (otherwise uses greedy decoding)
repetition_penalty=1.1, # Penalizes repeated tokens to avoid repetitive responses
pad_token_id=tokenizer.pad_token_id, # Padding token ID
eos_token_id=tokenizer.eos_token_id, # End-of-sequence token ID to stop generation
)
        # Decode only the newly generated tokens. This is more robust than decoding the whole
        # sequence and slicing off the prompt, because decode(encode(text)) does not always
        # reproduce the original prompt string exactly.
        generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        assistant_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# --- Post-processing for cleaner responses ---
# Remove any leading "Assistant:" if the model generated it as part of the response
if assistant_response.startswith("Assistant:"):
assistant_response = assistant_response[len("Assistant:"):].strip()
# Truncate response if the model starts generating a new 'User:' or 'System:' turn
if "User:" in assistant_response:
assistant_response = assistant_response.split("User:")[0].strip()
if "System:" in assistant_response:
assistant_response = assistant_response.split("System:")[0].strip()
# Return the updated chat history including the new user message and assistant's response
return history + [{"role": "user", "content": message},
{"role": "assistant", "content": assistant_response}]
except Exception as e:
error_msg = f"❌ Error during chat: {str(e)}"
print(error_msg)
return history + [{"role": "user", "content": message},
{"role": "assistant", "content": error_msg}]
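# Note: distilgpt2 has a 1024-token context window and the prompt built above is not truncated,
# so very long chat histories can exceed the model's maximum input length.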
# --- Model Saving ---
def save_model(output_path):
"""Saves the fine-tuned model and its tokenizer to a specified local directory."""
global model, tokenizer
state = status_manager.get_status()
if not state['model_trained']:
return "❌ Please complete training first before saving!"
try:
# Save the PEFT model and the tokenizer
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
return f"βœ… Model saved to {output_path}!"
except Exception as e:
return f"❌ Error saving model: {str(e)}"
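# Note: when LoRA adapters are applied, save_pretrained() on the PeftModel stores only the adapter
# weights and config, not the full base model; reloading later means loading distilgpt2 again and
# attaching the adapters (e.g. with PeftModel.from_pretrained).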
# --- Sample Queries for UI ---
SAMPLE_QUERIES = [
    "Ina jin ciwon kai da zazzabi tun kwana biyu. Me ya kamata in yi?",  # "I've had a headache and fever for two days. What should I do?"
    "Dana yana da gudawa sosai. Ina bukatan taimako.",  # "My son has severe diarrhoea. I need help."
    "Yaya ake hana malaria lokacin damina?",  # "How is malaria prevented during the rainy season?"
    "Ina da ciwon sukari. Wanne abinci ya dace da ni?",  # "I have diabetes. Which foods are right for me?"
]
# --- Gradio Interface Definition ---
with gr.Blocks(title="Hausa Health Assistant", theme=gr.themes.Soft()) as app:
gr.Markdown("# πŸ₯ Hausa Health Assistant")
gr.Markdown("Train and test an AI health assistant in Hausa language")
# Display for current application status
status_display = gr.Markdown("πŸ“Š **Current Status**: Ready")
with gr.Tabs():
# --- Model Management Tab ---
with gr.TabItem("πŸ€– Model Management"):
with gr.Row():
with gr.Column():
load_btn = gr.Button("πŸš€ Load Base Model", variant="primary")
prep_btn = gr.Button("βš™οΈ Prepare for Training (LoRA)")
gr.Markdown("### Training Parameters")
# Adjusted default parameters for more feasible CPU testing
batch_size = gr.Slider(1, 4, 1, step=1, label="Batch Size", info="Per device batch size for training. Start with 1 on CPU to avoid OOM.")
                    grad_accum = gr.Slider(1, 8, 1, step=1, label="Gradient Accumulation", info="Number of update steps to accumulate gradients over. Start with 1 for debugging.")
epochs = gr.Slider(1, 3, 1, step=1, label="Epochs", info="Number of training epochs. Start with 1 for initial tests.")
learning_rate = gr.Slider(1e-5, 5e-4, 2e-4, label="Learning Rate", info="Initial learning rate for the optimizer.")
train_btn = gr.Button("🎯 Start Training", variant="primary")
gr.Markdown("### Save Model")
save_path = gr.Textbox(value="hausa-health-assistant-finetuned", label="Save Path", info="Directory to save the fine-tuned model and tokenizer.")
save_btn = gr.Button("πŸ’Ύ Save Model")
with gr.Column():
operation_status = gr.Textbox(label="Operation Status Log", lines=3, interactive=False, value="Awaiting operations...")
# --- Chat Tab ---
with gr.TabItem("πŸ’¬ Chat"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Sample Queries")
sample_btns = [] # Store references to sample query buttons
for query_text in SAMPLE_QUERIES:
btn = gr.Button(f"{query_text[:40]}...", size="sm")
sample_btns.append((btn, query_text))
with gr.Column(scale=2):
chatbot = gr.Chatbot(label="Health Assistant", height=400, type="messages", layout="bubble")
                    msg = gr.Textbox(label="Your Message (Hausa)", placeholder="Ka rubuta tambayarka...")  # placeholder: "Write your question..."
with gr.Row():
temp = gr.Slider(0.1, 2.0, 1.0, label="Temperature", info="Controls the randomness of the model's output (higher = more creative).")
tokens = gr.Slider(50, 500, 200, label="Max Tokens", info="Maximum number of new tokens to generate in the response.")
with gr.Row():
send_btn = gr.Button("πŸ“€ Send", variant="primary")
clear_btn = gr.Button("πŸ—‘οΈ Clear Chat")
# --- Event Handlers for Gradio Components ---
load_btn.click(start_model_loading, outputs=[operation_status])
prep_btn.click(prepare_model_for_training, outputs=[operation_status])
train_btn.click(
start_training,
inputs=[batch_size, grad_accum, epochs, learning_rate],
outputs=[operation_status]
)
save_btn.click(save_model, inputs=[save_path], outputs=[operation_status])
# Chat interaction
send_btn.click(
chat_with_model,
inputs=[msg, chatbot, temp, tokens],
outputs=[chatbot]
).then(lambda: gr.Textbox(value="", interactive=True), outputs=[msg]) # Clear input after sending
clear_btn.click(lambda: [], outputs=[chatbot]) # Clear chat history
# Attach click handlers to dynamically created sample query buttons
for btn, query in sample_btns:
# Use a lambda with a default argument to capture the current query value
btn.click(lambda q=query: gr.Textbox(value=q, interactive=True), inputs=[], outputs=[msg])
# Auto-update status display every 2 seconds using a Gradio Timer
status_timer = gr.Timer(value=2)
status_timer.tick(get_current_status, outputs=[status_display])
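    # Note: gr.Timer is only available in recent Gradio releases; on older versions, a manual
    # "Refresh Status" button wired to get_current_status would be a simple fallback.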
# --- Main Application Entry Point ---
def main():
"""Main function to parse command-line arguments and launch the Gradio application."""
parser = argparse.ArgumentParser(description="Hausa Health Assistant Training App")
# Argument to enable Gradio's shareable link feature (defaulting to False)
parser.add_argument("--share", action="store_true", default=False, help="Create a shareable link for the Gradio app.")
args = parser.parse_args()
# Clear CUDA cache at startup if a GPU is available
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Set environment variable to suppress tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
    print("Starting Hausa Health Assistant app...")
# Launch the Gradio app, allowing access from any IP (0.0.0.0) and a specific port.
# 'share=args.share' will create a publicly shareable link if the --share flag is used.
app.launch(server_name="0.0.0.0", server_port=7860, share=args.share)
if __name__ == "__main__":
main()
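# Example usage (assuming torch, transformers, peft, datasets, and gradio are installed):
#   python app.py            # serve the UI on http://0.0.0.0:7860
#   python app.py --share    # additionally create a temporary public Gradio share link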