import os
import argparse
import torch
import gradio as gr
import threading
import time
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType, PeftModel  # PeftModel imported for a robust "already wrapped" check
import random

# Set seeds for reproducible results
random.seed(42)
torch.manual_seed(42)
# Determine the execution device: prefer CUDA (GPU) if available, otherwise use CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Global variables holding the loaded model, tokenizer, and training-related state
model = None
tokenizer = None
training_stats = {}
trainer = None
train_dataset = None
eval_dataset = None
# --- Status Management (Thread-Safe) ---
# Manages the application's status, progress, and error messages in a
# thread-safe way, since background threads modify them while the UI reads them.
class StatusManager:
    def __init__(self):
        self._lock = threading.Lock()  # Ensures only one thread modifies status at a time
        self.status = "Ready"          # Current descriptive status message
        self.progress = 0              # Progress percentage (0-100)
        self.model_loaded = False      # Has the base model been loaded?
        self.model_trained = False     # Has the model completed training?
        self.error = None              # Last error message encountered, if any

    def update_status(self, status: str, progress: int = None, error: str = None):
        """Updates the current status, optional progress percentage, and optional error message."""
        with self._lock:  # Acquire the lock before modifying shared state
            self.status = status
            if progress is not None:
                self.progress = progress
            if error is not None:
                self.error = error

    def set_model_loaded(self, loaded: bool):
        """Sets the flag indicating whether the model has been loaded."""
        with self._lock:
            self.model_loaded = loaded

    def set_model_trained(self, trained: bool):
        """Sets the flag indicating whether the model has completed training."""
        with self._lock:
            self.model_trained = trained

    def get_status(self):
        """Returns a snapshot dictionary of the current status, progress, and flags."""
        with self._lock:  # Acquire the lock before reading shared state
            return {
                'status': self.status,
                'progress': self.progress,
                'model_loaded': self.model_loaded,
                'model_trained': self.model_trained,
                'error': self.error,
            }

status_manager = StatusManager()
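
# Illustrative sketch (not called anywhere in the app): how a background worker
# and the UI are expected to interact with the shared StatusManager. The worker
# posts progress under the lock; the UI polls a consistent snapshot.
def _status_manager_usage_example():
    def worker():
        for pct in (0, 50, 100):
            status_manager.update_status(f"Working... {pct}%", pct)
    t = threading.Thread(target=worker, daemon=True)
    t.start()
    t.join()
    # e.g. {'status': 'Working... 100%', 'progress': 100, 'model_loaded': False, ...}
    return status_manager.get_status()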
# --- Model Loading ---
def initialize_model_background():
    """
    Loads the base pre-trained language model (distilgpt2) and its tokenizer
    in a background thread to keep the Gradio UI responsive.
    """
    global model, tokenizer
    try:
        status_manager.update_status("🔄 Loading base distilgpt2 model...", 10)
        # Clear the CUDA cache if a GPU is available to free memory before loading a new model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        status_manager.update_status("🔄 Downloading model weights (this might take a while)...", 30)
        # distilgpt2 keeps the computation light
        model_name = "distilgpt2"

        # Load the tokenizer associated with the model
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,  # Allow custom code in the model config if necessary
        )
        # Ensure the tokenizer has a padding token, which batched training requires
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token  # Fall back to the EOS token as padding

        status_manager.update_status("🔄 Loading model into memory...", 50)
        # Load the causal language model
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # float16 on GPU (half precision: faster, less memory), float32 on CPU
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            # device_map="auto" distributes model layers across available GPUs;
            # on CPU it should be None, and the model is moved explicitly below.
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
        )

        status_manager.update_status("🔄 Moving model to target device (GPU/CPU)...", 90)
        # Without device_map="auto" (i.e., on CPU), move the model explicitly
        if not torch.cuda.is_available():
            model = model.to(device)

        status_manager.update_status("✅ Model loaded successfully!", 100)
        status_manager.set_model_loaded(True)
        # Report the parameter count to give a sense of model size
        param_count = sum(p.numel() for p in model.parameters())
        print(f"✅ Model initialized! Parameters: {param_count/1e6:.2f}M")  # Displayed in millions
    except Exception as e:
        error_msg = f"❌ Error loading model: {str(e)}"
        status_manager.update_status("❌ Model loading failed", 0, error_msg)
        print(error_msg)
def start_model_loading():
    """Initiates the model loading process in a background thread."""
    if status_manager.get_status()['model_loaded']:
        return "Model already loaded!"  # Prevent loading twice
    thread = threading.Thread(target=initialize_model_background, daemon=True)
    thread.start()
    return "🚀 Started loading model in background..."
# --- Model Preparation for Training (LoRA) ---
def prepare_model_for_training():
    """
    Applies LoRA (Low-Rank Adaptation) adapters to the base model, which makes
    fine-tuning far cheaper in memory and compute than updating all weights.
    """
    global model
    state = status_manager.get_status()
    if not state['model_loaded']:
        return "❌ Please load the model first!"
    if model is None:
        return "❌ Model not available!"
    try:
        status_manager.update_status("🔄 Configuring LoRA adapters...", 0)
        # Skip if LoRA adapters have already been applied
        if isinstance(model, PeftModel):
            status_manager.update_status("✅ Model already prepared for training", 100)
            return "✅ Model already prepared for training"
        # LoRA configuration; target modules match distilgpt2's (GPT-2-style) architecture
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,               # LoRA rank / attention dimension (e.g., 8, 16, 32)
            lora_alpha=16,     # Scaling factor for the LoRA update
            lora_dropout=0.1,  # Dropout probability for LoRA layers
            bias="none",       # Bias handling: none, all, or lora_only
            # Target modules for distilgpt2
            target_modules=["c_attn", "c_proj", "c_fc"],
        )
        # Wrap the base model; only the small adapter weights remain trainable
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()  # Summarizes trainable vs. total parameters
        status_manager.update_status("✅ LoRA configuration applied!", 100)
        return "✅ LoRA configuration applied! Model ready for training."
    except Exception as e:
        error_msg = f"❌ Error preparing model for LoRA: {str(e)}"
        status_manager.update_status("❌ Model preparation failed", 0, error_msg)
        return error_msg
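
# Illustrative sketch (not wired into the app): why LoRA is cheap. For a linear
# layer of shape (d_out, d_in), a rank-r adapter adds only r * (d_in + d_out)
# trainable parameters instead of d_out * d_in. The helper below reproduces the
# kind of summary that print_trainable_parameters() emits.
def _count_trainable(m):
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    total = sum(p.numel() for p in m.parameters())
    # With the config above on distilgpt2, the trainable share is well under 1%
    return trainable, total, 100.0 * trainable / total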
# --- Dataset Loading and Preprocessing ---
def format_chat_template(example):
    """
    Formats one conversation example from the dataset into the plain-text
    transcript the language model is trained on.
    """
    if 'conversations' not in example or not example['conversations']:
        return {"text": ""}
    formatted_text = ""
    # Optional: prepend a system message to every conversation. The Hausa text below
    # means roughly: "You are a careful and skilled doctor in the field of health care.
    # Give scientifically grounded advice suited to Nigerian culture."
    # formatted_text += "System: Kai likita ne mai hankali da ƙwarewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya.\n"
    for turn in example['conversations']:
        role = turn.get('from', turn.get('role', ''))         # Tolerate both key conventions for roles
        content = turn.get('value', turn.get('content', ''))  # ...and for message content
        if role.lower() in ('human', 'user'):
            formatted_text += f"User: {content}\n"
        elif role.lower() in ('gpt', 'assistant'):
            formatted_text += f"Assistant: {content}\n"
    # Append the EOS token so the model learns where a conversation ends
    formatted_text += tokenizer.eos_token
    return {"text": formatted_text}
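
# Illustrative example (assumed ShareGPT-style record shape, not part of the pipeline):
#   format_chat_template({"conversations": [
#       {"from": "human", "value": "Ina jin ciwon kai."},
#       {"from": "gpt", "value": "Sha ruwa mai yawa ka huta."}]})
# returns (with "<|endoftext|>" being distilgpt2's EOS token):
#   {"text": "User: Ina jin ciwon kai.\nAssistant: Sha ruwa mai yawa ka huta.\n<|endoftext|>"}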
def load_dataset_background():
    """
    Loads the Hausa medical conversations dataset from the Hugging Face Hub and
    preprocesses it for training. Runs inside a background thread.
    """
    global train_dataset, eval_dataset
    try:
        status_manager.update_status("🔄 Loading Hausa medical dataset from Hugging Face Hub...", 10)
        dataset_name = "ictbiortc/hausa-medical-conversations-format-9k"
        dataset = load_dataset(dataset_name)
        if dataset is None:
            raise ValueError("Dataset not found or could not be loaded.")

        status_manager.update_status("🔄 Processing dataset (formatting conversations)...", 40)
        # If the dataset has no explicit 'test' split, carve one out of 'train'
        if 'test' not in dataset:
            print("No 'test' split found in dataset, creating a 10% test split from 'train'.")
            dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
        train_dataset_raw = dataset['train']
        eval_dataset_raw = dataset['test']

        # Apply the chat formatting to both splits
        train_dataset = train_dataset_raw.map(
            format_chat_template,
            remove_columns=train_dataset_raw.column_names,  # Keep only the new 'text' column
            desc="Formatting train dataset"
        )
        eval_dataset = eval_dataset_raw.map(
            format_chat_template,
            remove_columns=eval_dataset_raw.column_names,
            desc="Formatting eval dataset"
        )
        # Filter out any examples that came out empty after formatting
        train_dataset = train_dataset.filter(lambda x: len(x['text'].strip()) > 0, desc="Filtering empty train examples")
        eval_dataset = eval_dataset.filter(lambda x: len(x['text'].strip()) > 0, desc="Filtering empty eval examples")

        status_manager.update_status(f"✅ Dataset loaded! Train samples: {len(train_dataset)}, Validation samples: {len(eval_dataset)}", 100)
        print(f"Dataset loading complete: Train samples={len(train_dataset)}, Eval samples={len(eval_dataset)}")
    except Exception as e:
        error_msg = f"❌ Error loading or processing dataset: {str(e)}"
        status_manager.update_status("❌ Dataset loading failed", 0, error_msg)
        print(error_msg)
        train_dataset = None
        eval_dataset = None
# --- Custom Data Collator ---
# The collator tokenizes and pads each batch so that every sequence in the
# batch has the same length, which efficient tensor processing requires.
class DataCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        texts = [example['text'] for example in examples]
        # Tokenize the batch of texts
        tokenized = self.tokenizer(
            texts,
            truncation=True,       # Truncate sequences longer than max_length
            padding="max_length",  # Pad to max_length for consistency within the batch
            max_length=self.max_length,
            return_tensors="pt"    # Return PyTorch tensors
        )
        # For causal language modeling, the labels are the input_ids themselves
        # (the model predicts the next token). Padding positions are set to -100
        # so the loss ignores them; this matters here because the pad token
        # doubles as the EOS token.
        labels = tokenized['input_ids'].clone()
        labels[tokenized['attention_mask'] == 0] = -100
        tokenized['labels'] = labels
        return tokenized
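
# Illustrative sketch (not called by the app; requires the tokenizer to be
# loaded first): the tensors the collator yields for a toy batch.
def _collator_shape_example():
    collator = DataCollator(tokenizer, max_length=16)
    batch = collator([{"text": "User: sannu\nAssistant: sannu da zuwa\n"},
                      {"text": "User: na gode\nAssistant: ba komai\n"}])
    # batch['input_ids'], batch['attention_mask'], and batch['labels'] all have
    # shape (2, 16); labels are -100 wherever attention_mask is 0.
    return {k: v.shape for k, v in batch.items()}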
# --- Custom Training Progress Callback ---
class CustomProgressCallback(TrainerCallback):
    """
    A custom callback for the Hugging Face Trainer. It pushes real-time training
    progress (step count, percentage, and current loss) to the status manager
    that backs the Gradio UI.
    """
    def __init__(self, status_manager_instance):
        self.status_manager = status_manager_instance
        self.last_logged_progress = -1
        self.last_logged_loss = None

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Called whenever the Trainer logs something (e.g., loss, learning rate)."""
        if logs is None:
            return
        current_loss = logs.get('loss')
        if current_loss is not None:
            self.last_logged_loss = current_loss
        # Derive progress from the global step once max_steps is known
        if state.max_steps > 0:
            progress = int((state.global_step / state.max_steps) * 100)
            # Only update when progress has changed, or on the very first step
            if progress != self.last_logged_progress or state.global_step == 1:
                loss_info = f", Loss: {self.last_logged_loss:.4f}" if self.last_logged_loss is not None else ""
                self.status_manager.update_status(
                    f"🚀 Training... Step {state.global_step}/{state.max_steps}{loss_info}",
                    progress
                )
                self.last_logged_progress = progress
        else:
            # Fallback when max_steps isn't set yet
            self.status_manager.update_status(f"🚀 Training... Step {state.global_step}", None)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each training step."""
        # on_log captures the loss; this hook keeps the progress bar moving
        # even on steps where no loss is logged.
        if state.max_steps > 0:
            progress = int((state.global_step / state.max_steps) * 100)
            if progress > self.last_logged_progress or (state.global_step == 1 and self.last_logged_progress == -1):
                loss_info = f", Loss: {self.last_logged_loss:.4f}" if self.last_logged_loss is not None else ""
                self.status_manager.update_status(
                    f"🚀 Training... Step {state.global_step}/{state.max_steps}{loss_info}",
                    progress
                )
                self.last_logged_progress = progress

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each training epoch."""
        self.status_manager.update_status(f"✅ Epoch {int(state.epoch)} completed.",
                                          int((state.global_step / state.max_steps) * 100) if state.max_steps > 0 else None)

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the very end of the training process."""
        # train_model_background sets the final status; nothing extra is needed here.
        pass
# --- Model Training Function ---
def train_model_background(batch_size, grad_accum, epochs, lr):
    """
    Manages the full training lifecycle in a background thread: loading the
    dataset (if needed), setting up the Trainer, and running training.
    """
    global model, tokenizer, trainer, training_stats, train_dataset, eval_dataset
    # PyTorch anomaly detection is disabled for speed. Re-enable the lines below
    # if in-place modification errors need detailed traces.
    # torch.autograd.set_detect_anomaly(True)
    # print("PyTorch anomaly detection is ENABLED. Training will be slower but errors get detailed traces.")
    print("PyTorch anomaly detection is DISABLED for faster training.")
    try:
        # Step 1: Ensure the dataset is loaded and ready
        if train_dataset is None or eval_dataset is None:
            status_manager.update_status("🔄 Loading dataset for training...", 5)
            load_dataset_background()  # Runs synchronously here, since we are already on a worker thread
        if train_dataset is None or len(train_dataset) == 0:
            error_msg = "❌ Training dataset is empty or failed to load - training cannot proceed."
            status_manager.update_status("❌ Training failed", 0, error_msg)
            return
        if eval_dataset is None or len(eval_dataset) == 0:
            print("Warning: Evaluation dataset is empty or failed to load. Evaluation during training will be skipped.")

        status_manager.update_status("🔄 Setting up Hugging Face Trainer...", 10)
        print(f"Trainer setup: Batch size={batch_size}, Grad Accum={grad_accum}, Epochs={epochs}, LR={lr}")

        # Step 2: Initialize the data collator
        data_collator = DataCollator(tokenizer)
        print("Data collator initialized.")

        # Step 3: Configure training arguments
        training_args = TrainingArguments(
            output_dir="./results",                 # Directory for saving checkpoints
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,  # Same batch size for evaluation
            gradient_accumulation_steps=grad_accum,
            num_train_epochs=epochs,
            learning_rate=lr,
            warmup_steps=10,   # Steps of linear learning-rate warmup
            logging_steps=1,   # Log every step for granular UI feedback
            save_steps=100,    # Save checkpoints frequently for CPU test runs
            eval_steps=100,    # Evaluation interval (only takes effect if an evaluation strategy is enabled)
            save_total_limit=2,           # Keep only the last 2 checkpoints to save disk space
            remove_unused_columns=False,  # Needed because dataset columns don't directly match model inputs
            dataloader_drop_last=True,    # Drop the last incomplete batch for consistent batch sizes
            report_to=None,               # Disable reporting to external services like Weights & Biases
            optim="adamw_torch",          # AdamW optimizer (PyTorch implementation)
            lr_scheduler_type="linear",   # Linear learning-rate decay
            seed=42,                      # Random seed for reproducibility
            # Pin memory for faster CPU-to-GPU transfer when a GPU is present
            dataloader_pin_memory=True if torch.cuda.is_available() else False,
            # On CPU, small logging/eval intervals give frequent feedback at little cost
        )
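
        # Note on the effective batch size: gradients are accumulated over
        # grad_accum steps before each optimizer update, so
        #   effective batch size = per_device_train_batch_size * gradient_accumulation_steps
        # e.g. batch_size=2 with grad_accum=4 updates on 8 examples' worth of gradients.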
        # Step 4: Initialize the Hugging Face Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            # Pass eval_dataset only when it exists and is non-empty; otherwise the Trainer may error
            eval_dataset=eval_dataset if eval_dataset is not None and len(eval_dataset) > 0 else None,
            data_collator=data_collator,
            tokenizer=tokenizer,  # Lets the Trainer use the tokenizer internally (e.g., when saving)
        )
        # Attach the custom progress callback that feeds the UI
        trainer.add_callback(CustomProgressCallback(status_manager))

        status_manager.update_status("🚀 Starting training process...", 30)
        print("Training initiated...")

        # Step 5: Run the training loop
        train_result = trainer.train()

        # Step 6: Store the final training results
        training_stats = train_result
        status_manager.set_model_trained(True)

        # Compute and display total training time
        training_time = train_result.metrics.get('train_runtime', 0)
        minutes = int(training_time // 60)
        seconds = int(training_time % 60)
        final_loss = train_result.training_loss if hasattr(train_result, 'training_loss') else train_result.metrics.get('train_loss', 'N/A')
        # Guard the format spec: final_loss may be the string 'N/A'
        loss_str = f"{final_loss:.4f}" if isinstance(final_loss, (int, float)) else str(final_loss)
        success_msg = f"🎉 Training completed! Final Loss: {loss_str}, Time: {minutes}m {seconds}s"
        status_manager.update_status(success_msg, 100)
        print(success_msg)
    except Exception as e:
        error_msg = f"❌ Training failed: {str(e)}"
        status_manager.update_status("❌ Training failed", 0, error_msg)
        print(error_msg)
def start_training(batch_size, grad_accum, epochs, lr):
    """Initiates the model training process in a dedicated background thread."""
    state = status_manager.get_status()
    if not state['model_loaded']:
        return "❌ Please load the model first!"
    if model is None:
        return "❌ Model not available!"
    # Start the training thread with the provided parameters
    thread = threading.Thread(
        target=train_model_background,
        args=(batch_size, grad_accum, epochs, lr),
        daemon=True  # Daemon threads exit automatically when the main program exits
    )
    thread.start()
    return "🚀 Started training in background..."
# --- Status Retrieval for UI ---
def get_current_status():
    """Reads the current application status from the StatusManager and formats it for Gradio display."""
    state = status_manager.get_status()
    status_text = f"""📊 **Current Status**: {state['status']}
📈 **Progress**: {state['progress']}%
🤖 **Model Loaded**: {'✅' if state['model_loaded'] else '❌'}
🎓 **Model Trained**: {'✅' if state['model_trained'] else '❌'}"""
    if state['error']:
        status_text += f"\n❌ **Error**: {state['error']}"
    return status_text
# --- Chat Functionality (Inference) ---
def chat_with_model(message, history, temperature=1.0, max_tokens=200):
    """
    Generates a conversational response from the loaded model, given the user's
    message and the ongoing chat history.
    """
    global model, tokenizer
    state = status_manager.get_status()
    if not state['model_loaded'] or model is None:
        # Without a loaded model, surface an error message in the chat
        return history + [{"role": "user", "content": message},
                          {"role": "assistant", "content": "❌ Please load the model first!"}]

    # Serialize the conversation history into a single role-tagged transcript
    conversation = ""
    # Optional: prepend a system message for inference if training didn't include one
    # (same Hausa doctor persona as in format_chat_template).
    # system_message = "Kai likita ne mai hankali da ƙwarewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya.\n"
    # conversation += system_message
    for msg_pair in history:  # history is a list of {"role": ..., "content": ...} dictionaries
        if msg_pair["role"] == "user":
            conversation += f"User: {msg_pair['content']}\n"
        else:  # "assistant"
            conversation += f"Assistant: {msg_pair['content']}\n"
    # Append the current user message and cue the assistant to respond
    conversation += f"User: {message}\nAssistant:"
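
    # Illustrative example of the prompt this produces (assumed two-turn history):
    #   User: Ina jin zazzabi.
    #   Assistant: Auna zafin jikinka.
    #   User: Me zan sha?
    #   Assistant:
    # The model then continues the text after the trailing "Assistant:".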
    try:
        # Tokenize the entire conversation string for model input
        inputs = tokenizer(conversation, return_tensors="pt").to(device)
        with torch.no_grad():  # No gradients during inference: faster and lighter on memory
            # Generate the model's response
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,  # Maximum number of tokens to generate for the response
                temperature=temperature,    # Randomness: higher = more creative, lower = more deterministic
                top_p=0.95,                 # Nucleus sampling: sample from the top-p probability mass
                top_k=50,                   # Top-k sampling: sample from the k most probable tokens
                do_sample=True,             # Enable sampling (otherwise greedy decoding)
                repetition_penalty=1.1,     # Penalize repeated tokens to curb repetitive responses
                pad_token_id=tokenizer.pad_token_id,  # Padding token ID
                eos_token_id=tokenizer.eos_token_id,  # Stop generation at end-of-sequence
            )
        # Decode the generated sequence back into human-readable text
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the assistant's reply by stripping the prompt prefix
        assistant_response = full_response[len(conversation):].strip()

        # --- Post-processing for cleaner responses ---
        # Drop a leading "Assistant:" if the model echoed it
        if assistant_response.startswith("Assistant:"):
            assistant_response = assistant_response[len("Assistant:"):].strip()
        # Truncate if the model starts generating a new 'User:' or 'System:' turn
        if "User:" in assistant_response:
            assistant_response = assistant_response.split("User:")[0].strip()
        if "System:" in assistant_response:
            assistant_response = assistant_response.split("System:")[0].strip()

        # Return the updated history with the new user message and assistant reply
        return history + [{"role": "user", "content": message},
                          {"role": "assistant", "content": assistant_response}]
    except Exception as e:
        error_msg = f"❌ Error during chat: {str(e)}"
        print(error_msg)
        return history + [{"role": "user", "content": message},
                          {"role": "assistant", "content": error_msg}]
# --- Model Saving ---
def save_model(output_path):
    """Saves the fine-tuned model and its tokenizer to a specified local directory."""
    global model, tokenizer
    state = status_manager.get_status()
    if not state['model_trained']:
        return "❌ Please complete training first before saving!"
    try:
        # Save the PEFT adapter weights and the tokenizer
        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        return f"✅ Model saved to {output_path}!"
    except Exception as e:
        return f"❌ Error saving model: {str(e)}"
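
# Illustrative sketch (not wired into the UI): reloading the saved model later.
# save_pretrained() on a PEFT model stores only the LoRA adapter, so the base
# model must be loaded first and the adapter applied on top of it.
def _reload_finetuned_example(output_path="hausa-health-assistant-finetuned"):
    base = AutoModelForCausalLM.from_pretrained("distilgpt2")
    tok = AutoTokenizer.from_pretrained(output_path)
    tuned = PeftModel.from_pretrained(base, output_path)  # Base weights + LoRA adapter
    return tuned, tok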
# --- Sample Queries for UI ---
SAMPLE_QUERIES = [
    "Ina jin ciwon kai da zazzabi tun kwana biyu. Me ya kamata in yi?",  # "I've had a headache and fever for two days. What should I do?"
    "Dana yana da gudawa sosai. Ina bukatan taimako.",                   # "My son has severe diarrhea. I need help."
    "Yaya ake hana malaria lokacin damina?",                             # "How is malaria prevented during the rainy season?"
    "Ina da ciwon sukari. Wanne abinci ya dace da ni?",                  # "I have diabetes. Which foods suit me?"
]
# --- Gradio Interface Definition ---
with gr.Blocks(title="Hausa Health Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🏥 Hausa Health Assistant")
    gr.Markdown("Train and test an AI health assistant in the Hausa language")

    # Live display of the current application status
    status_display = gr.Markdown("📊 **Current Status**: Ready")

    with gr.Tabs():
        # --- Model Management Tab ---
        with gr.TabItem("🤖 Model Management"):
            with gr.Row():
                with gr.Column():
                    load_btn = gr.Button("📥 Load Base Model", variant="primary")
                    prep_btn = gr.Button("⚙️ Prepare for Training (LoRA)")
                    gr.Markdown("### Training Parameters")
                    # Defaults chosen so CPU test runs stay feasible
                    batch_size = gr.Slider(1, 4, 1, step=1, label="Batch Size", info="Per-device batch size for training. Start with 1 on CPU to avoid running out of memory.")
                    grad_accum = gr.Slider(1, 8, 1, step=1, label="Gradient Accumulation", info="Number of steps to accumulate gradients over. Start with 1 for debugging.")
                    epochs = gr.Slider(1, 3, 1, step=1, label="Epochs", info="Number of training epochs. Start with 1 for initial tests.")
                    learning_rate = gr.Slider(1e-5, 5e-4, 2e-4, label="Learning Rate", info="Initial learning rate for the optimizer.")
                    train_btn = gr.Button("🎯 Start Training", variant="primary")
                    gr.Markdown("### Save Model")
                    save_path = gr.Textbox(value="hausa-health-assistant-finetuned", label="Save Path", info="Directory to save the fine-tuned model and tokenizer.")
                    save_btn = gr.Button("💾 Save Model")
                with gr.Column():
                    operation_status = gr.Textbox(label="Operation Status Log", lines=3, interactive=False, value="Awaiting operations...")

        # --- Chat Tab ---
        with gr.TabItem("💬 Chat"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Sample Queries")
                    sample_btns = []  # References to the sample query buttons
                    for query_text in SAMPLE_QUERIES:
                        btn = gr.Button(f"{query_text[:40]}...", size="sm")
                        sample_btns.append((btn, query_text))
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(label="Health Assistant", height=400, type="messages", layout="bubble")
                    msg = gr.Textbox(label="Your Message (Hausa)", placeholder="Ka rubuta tambayarka...")  # "Write your question..."
                    with gr.Row():
                        temp = gr.Slider(0.1, 2.0, 1.0, label="Temperature", info="Controls the randomness of the model's output (higher = more creative).")
                        tokens = gr.Slider(50, 500, 200, label="Max Tokens", info="Maximum number of new tokens to generate in the response.")
                    with gr.Row():
                        send_btn = gr.Button("📤 Send", variant="primary")
                        clear_btn = gr.Button("🗑️ Clear Chat")
    # --- Event Handlers for Gradio Components ---
    load_btn.click(start_model_loading, outputs=[operation_status])
    prep_btn.click(prepare_model_for_training, outputs=[operation_status])
    train_btn.click(
        start_training,
        inputs=[batch_size, grad_accum, epochs, learning_rate],
        outputs=[operation_status]
    )
    save_btn.click(save_model, inputs=[save_path], outputs=[operation_status])

    # Chat interaction
    send_btn.click(
        chat_with_model,
        inputs=[msg, chatbot, temp, tokens],
        outputs=[chatbot]
    ).then(lambda: gr.Textbox(value="", interactive=True), outputs=[msg])  # Clear the input box after sending
    clear_btn.click(lambda: [], outputs=[chatbot])  # Clear the chat history

    # Attach click handlers to the dynamically created sample query buttons
    for btn, query in sample_btns:
        # A default argument in the lambda captures the current query value
        btn.click(lambda q=query: gr.Textbox(value=q, interactive=True), inputs=[], outputs=[msg])

    # Refresh the status display every 2 seconds using a Gradio Timer
    status_timer = gr.Timer(value=2)
    status_timer.tick(get_current_status, outputs=[status_display])
# --- Main Application Entry Point ---
def main():
    """Parses command-line arguments and launches the Gradio application."""
    parser = argparse.ArgumentParser(description="Hausa Health Assistant Training App")
    # Enable Gradio's shareable-link feature (off by default)
    parser.add_argument("--share", action="store_true", default=False, help="Create a shareable link for the Gradio app.")
    args = parser.parse_args()

    # Clear the CUDA cache at startup if a GPU is available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Suppress the tokenizers parallelism warning
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    print("Starting Hausa Health Assistant app...")
    # Listen on all interfaces (0.0.0.0) on port 7860; '--share' creates a publicly shareable link.
    app.launch(server_name="0.0.0.0", server_port=7860, share=args.share)

if __name__ == "__main__":
    main()