# --- Imports and global setup ---
import os
import argparse
import torch
import gradio as gr
import threading
import time
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType, PeftModel  # PeftModel imported for a robust "already wrapped" check
import random

# Seed both Python's and PyTorch's RNGs so repeated runs are reproducible.
random.seed(42)
torch.manual_seed(42)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Global state shared between the Gradio callbacks and background worker threads.
model = None
tokenizer = None
training_stats = {}
trainer = None
train_dataset = None
eval_dataset = None


# --- Status Management (Thread-Safe) ---
class StatusManager:
    """Thread-safe holder for the app's status, progress and error state.

    Background threads (model loading, dataset loading, training) write to it
    while the Gradio UI timer reads from it, so every access takes a lock.
    """

    def __init__(self):
        self._lock = threading.Lock()  # Guards all mutable state below
        self.status = "Ready"          # Current human-readable status message
        self.progress = 0              # Progress percentage (0-100)
        self.model_loaded = False      # True once the base model is in memory
        self.model_trained = False     # True once fine-tuning has completed
        self.error = None              # Last error message encountered, if any

    def update_status(self, status: str, progress: int = None, error: str = None):
        """Update the status message and, optionally, progress percentage and error."""
        with self._lock:
            self.status = status
            if progress is not None:
                self.progress = progress
            if error is not None:
                self.error = error

    def set_model_loaded(self, loaded: bool):
        """Record whether the base model has been loaded."""
        with self._lock:
            self.model_loaded = loaded

    def set_model_trained(self, trained: bool):
        """Record whether fine-tuning has completed."""
        with self._lock:
            self.model_trained = trained

    def get_status(self):
        """Return a snapshot of the current state as a dictionary."""
        with self._lock:
            return {
                'status': self.status,
                'progress': self.progress,
                'model_loaded': self.model_loaded,
                'model_trained': self.model_trained,
                'error': self.error,
            }


status_manager = StatusManager()


# --- Model Loading ---
def initialize_model_background():
    """Load the base distilgpt2 model and its tokenizer in a background thread.

    Runs off the main thread so the Gradio UI stays responsive; progress and
    errors are reported through the global ``status_manager``.
    """
    global model, tokenizer
    try:
        status_manager.update_status("πŸ”„ Loading base distilgpt2 model...", 10)
        # Free any cached GPU memory before loading a new model.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        status_manager.update_status("πŸ”„ Downloading model weights (this might take a while)...", 30)
        # distilgpt2 keeps compute requirements light enough for CPU testing.
        model_name = "distilgpt2"
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,  # Allow custom code in model config if necessary
        )
        # Batched training needs a pad token; GPT-2 style tokenizers define
        # none, so reuse the EOS token for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        status_manager.update_status("πŸ”„ Loading model into memory...", 50)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # Half precision on GPU for speed/memory; full precision on CPU.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            # device_map="auto" spreads layers across available GPUs; on CPU we
            # load normally and move the model explicitly below.
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
        )
        status_manager.update_status("πŸ”„ Moving model to target device (GPU/CPU)...", 90)
        # Without device_map="auto" (CPU path) the model must be moved by hand.
        if not torch.cuda.is_available():
            model = model.to(device)
        status_manager.update_status("βœ… Model loaded successfully!", 100)
        status_manager.set_model_loaded(True)
        # Report the parameter count to give an idea of model size.
        param_count = sum(p.numel() for p in model.parameters())
        print(f"βœ… Model initialized! Parameters: {param_count/1e6:.2f}M")
    except Exception as e:
        error_msg = f"❌ Error loading model: {str(e)}"
        status_manager.update_status("❌ Model loading failed", 0, error_msg)
        print(error_msg)


def start_model_loading():
    """Kick off model loading in a daemon thread; returns a status string for the UI."""
    if status_manager.get_status()['model_loaded']:
        return "Model already loaded!"  # Prevent loading multiple times
    thread = threading.Thread(target=initialize_model_background, daemon=True)
    thread.start()
    return "πŸš€ Started loading model in background..."


# --- Model Preparation for Training (LoRA) ---
def prepare_model_for_training():
    """Wrap the base model with LoRA adapters so only a small set of weights trains.

    Returns a status string for the UI; safe to call twice (detects an
    already-wrapped PeftModel and does nothing).
    """
    global model
    state = status_manager.get_status()
    if not state['model_loaded']:
        return "❌ Please load the model first!"
    if model is None:
        return "❌ Model not available!"
    try:
        status_manager.update_status("πŸ”„ Configuring LoRA adapters...", 0)
        # Check if LoRA adapters are already applied.
        if isinstance(model, PeftModel):
            status_manager.update_status("βœ… Model already prepared for training", 100)
            return "βœ… Model already prepared for training"
        # Target modules below are specific to distilgpt2's GPT-2 architecture.
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,                # LoRA attention dimension (e.g., 8, 16, 32)
            lora_alpha=16,      # Alpha parameter for LoRA scaling
            lora_dropout=0.1,   # Dropout probability for LoRA layers
            bias="none",        # Bias type (none, all, lora_only)
            target_modules=["c_attn", "c_proj", "c_fc"],
        )
        # Apply LoRA to the base model, making only a small portion trainable.
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()  # Summary of trainable vs. total parameters
        status_manager.update_status("βœ… LoRA configuration applied!", 100)
        return "βœ… LoRA configuration applied! Model ready for training."
    except Exception as e:
        error_msg = f"❌ Error preparing model for LoRA: {str(e)}"
        status_manager.update_status("❌ Model preparation failed", 0, error_msg)
        return error_msg


# --- Dataset Loading and Preprocessing ---
def format_chat_template(example):
    """Flatten one conversation example into a single "User:/Assistant:" text string.

    Handles both ``from``/``value`` and ``role``/``content`` key conventions.
    Returns ``{"text": ""}`` for examples without conversations (filtered later).
    """
    if 'conversations' not in example or not example['conversations']:
        return {"text": ""}
    formatted_text = ""
    # Optional: Add a system message at the beginning of each conversation
    # formatted_text += "System: Kai likita ne mai hankali da Ζ™warewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya.\n"
    for turn in example['conversations']:
        role = turn.get('from', turn.get('role', ''))        # Tolerate either role key
        content = turn.get('value', turn.get('content', '')) # Tolerate either content key
        if role.lower() == 'human' or role.lower() == 'user':
            formatted_text += f"User: {content}\n"
        elif role.lower() == 'gpt' or role.lower() == 'assistant':
            formatted_text += f"Assistant: {content}\n"
    # EOS marks the end of a conversation for the model.
    formatted_text += tokenizer.eos_token
    return {"text": formatted_text}


def load_dataset_background():
    """Load and preprocess the Hausa medical conversations dataset.

    Populates the global ``train_dataset``/``eval_dataset`` (or resets them to
    None on failure) and reports progress via ``status_manager``.
    """
    global train_dataset, eval_dataset
    try:
        status_manager.update_status("πŸ”„ Loading Hausa medical dataset from Hugging Face Hub...", 10)
        dataset_name = "ictbiortc/hausa-medical-conversations-format-9k"
        dataset = load_dataset(dataset_name)
        if dataset is None:
            raise ValueError("Dataset not found or could not be loaded.")
        status_manager.update_status("πŸ”„ Processing dataset (formatting conversations)...", 40)
        # If the dataset doesn't ship a 'test' split, carve one out of 'train'.
        if 'test' not in dataset:
            print("No 'test' split found in dataset, creating a 10% test split from 'train'.")
            dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
            train_dataset_raw = dataset['train']
            eval_dataset_raw = dataset['test']
        else:
            train_dataset_raw = dataset['train']
            eval_dataset_raw = dataset['test']
        # Apply the chat formatting to both splits, keeping only the 'text' column.
        train_dataset = train_dataset_raw.map(
            format_chat_template,
            remove_columns=train_dataset_raw.column_names,
            desc="Formatting train dataset"
        )
        eval_dataset = eval_dataset_raw.map(
            format_chat_template,
            remove_columns=eval_dataset_raw.column_names,
            desc="Formatting eval dataset"
        )
        # Drop any examples that formatted to an empty string.
        train_dataset = train_dataset.filter(lambda x: len(x['text'].strip()) > 0, desc="Filtering empty train examples")
        eval_dataset = eval_dataset.filter(lambda x: len(x['text'].strip()) > 0, desc="Filtering empty eval examples")
        status_manager.update_status(f"βœ… Dataset loaded! Train samples: {len(train_dataset)}, Validation samples: {len(eval_dataset)}", 100)
        print(f"Dataset loading complete: Train samples={len(train_dataset)}, Eval samples={len(eval_dataset)}")
    except Exception as e:
        error_msg = f"❌ Error loading or processing dataset: {str(e)}"
        status_manager.update_status("❌ Dataset loading failed", 0, error_msg)
        print(error_msg)
        train_dataset = None
        eval_dataset = None


# --- Custom Data Collator ---
class DataCollator:
    """Tokenize and pad a batch of ``{'text': ...}`` examples for causal LM training."""

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        texts = [example['text'] for example in examples]
        tokenized = self.tokenizer(
            texts,
            truncation=True,           # Truncate sequences longer than max_length
            padding="max_length",      # Pad to a fixed length for batch consistency
            max_length=self.max_length,
            return_tensors="pt"        # Return PyTorch tensors
        )
        # For causal LM the labels mirror the inputs (next-token prediction),
        # but padded positions are set to -100 so the cross-entropy loss
        # ignores them. This matters here because pad_token == eos_token:
        # without masking, the model is trained to emit EOS at every padded
        # position. The real EOS at the end of each conversation still has
        # attention_mask == 1, so it remains a learning target.
        labels = tokenized['input_ids'].clone()
        labels[tokenized['attention_mask'] == 0] = -100
        tokenized['labels'] = labels
        return tokenized


# --- Custom Training Progress Callback ---
class CustomProgressCallback(TrainerCallback):
    """Trainer callback that mirrors training progress into the StatusManager.

    Updates the Gradio status display with step counts, percentage complete
    and the most recently logged loss.
    """

    def __init__(self, status_manager_instance):
        self.status_manager = status_manager_instance
        self.last_logged_progress = -1  # Avoids redundant UI updates
        self.last_logged_loss = None    # Most recent loss seen in on_log

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Called whenever the Trainer logs metrics (loss, learning rate, ...)."""
        if logs is None:
            return
        current_loss = logs.get('loss')
        if current_loss is not None:
            self.last_logged_loss = current_loss
        # Compute percentage from the global step when max_steps is known.
        if state.max_steps > 0:
            progress = int((state.global_step / state.max_steps) * 100)
            # Only update when progress moved, or on the very first step.
            if progress != self.last_logged_progress or state.global_step == 1:
                loss_info = f", Loss: {self.last_logged_loss:.4f}" if self.last_logged_loss is not None else ""
                self.status_manager.update_status(
                    f"πŸš€ Training... Step {state.global_step}/{state.max_steps}{loss_info}",
                    progress
                )
                self.last_logged_progress = progress
        else:
            # Fallback when max_steps is not yet set or is 0.
            self.status_manager.update_status(f"πŸš€ Training... Step {state.global_step}", None)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of every training step; keeps the progress bar fresh
        even on steps where nothing is logged."""
        if state.max_steps > 0:
            progress = int((state.global_step / state.max_steps) * 100)
            if progress > self.last_logged_progress or (state.global_step == 1 and self.last_logged_progress == -1):
                loss_info = f", Loss: {self.last_logged_loss:.4f}" if self.last_logged_loss is not None else ""
                self.status_manager.update_status(
                    f"πŸš€ Training... Step {state.global_step}/{state.max_steps}{loss_info}",
                    progress
                )
                self.last_logged_progress = progress

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of each training epoch."""
        self.status_manager.update_status(
            f"βœ… Epoch {int(state.epoch)} completed.",
            int((state.global_step / state.max_steps) * 100) if state.max_steps > 0 else None
        )

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the very end of training; final status is set by
        train_model_background, so nothing to do here."""
        pass


# --- Model Training Function ---
def train_model_background(batch_size, grad_accum, epochs, lr):
    """Run the full training lifecycle in a background thread.

    Loads the dataset if needed, builds TrainingArguments and a Trainer, runs
    training, and records the outcome in the StatusManager.

    Args:
        batch_size: Per-device train/eval batch size.
        grad_accum: Gradient accumulation steps.
        epochs: Number of training epochs.
        lr: Initial learning rate.
    """
    global model, tokenizer, trainer, training_stats, train_dataset, eval_dataset
    # Anomaly detection stays off for speed; re-enable if in-place
    # modification errors need a detailed trace.
    # torch.autograd.set_detect_anomaly(True)
    print("PyTorch anomaly detection is DISABLED for faster training.")
    try:
        # Step 1: Ensure the dataset is loaded and ready.
        if train_dataset is None or eval_dataset is None:
            status_manager.update_status("πŸ”„ Loading dataset for training...", 5)
            load_dataset_background()
        if train_dataset is None or len(train_dataset) == 0:
            error_msg = "❌ Training dataset is empty or failed to load - training cannot proceed."
            status_manager.update_status("❌ Training failed", 0, error_msg)
            return
        if eval_dataset is None or len(eval_dataset) == 0:
            print("Warning: Evaluation dataset is empty or failed to load. Evaluation during training will be skipped.")
        status_manager.update_status("πŸ”„ Setting up Hugging Face Trainer...", 10)
        print(f"Trainer setup: Batch size={batch_size}, Grad Accum={grad_accum}, Epochs={epochs}, LR={lr}")
        # Step 2: Initialize the data collator.
        data_collator = DataCollator(tokenizer)
        print("Data collator initialized.")
        # Step 3: Configure training arguments.
        # NOTE(review): eval_steps is set, but without an evaluation strategy
        # of "steps" the Trainer may never evaluate — confirm against the
        # installed transformers version before relying on mid-training eval.
        training_args = TrainingArguments(
            output_dir="./results",                    # Directory for checkpoints
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,     # Same batch size for evaluation
            gradient_accumulation_steps=grad_accum,
            num_train_epochs=epochs,
            learning_rate=lr,
            warmup_steps=10,                           # Linear LR warmup steps
            logging_steps=1,                           # Log every step for granular UI feedback
            save_steps=100,                            # Frequent checkpoints for CPU tests
            eval_steps=100,                            # Frequent evaluation for CPU tests
            save_total_limit=2,                        # Keep only the last 2 checkpoints
            remove_unused_columns=False,               # Dataset columns don't match model inputs directly
            dataloader_drop_last=True,                 # Drop incomplete last batch
            report_to=None,                            # No external reporting (e.g. W&B)
            optim="adamw_torch",                       # AdamW optimizer (PyTorch implementation)
            lr_scheduler_type="linear",                # Linear learning rate decay
            seed=42,                                   # Reproducibility
            # Pin memory only helps when transferring batches to a GPU.
            dataloader_pin_memory=True if torch.cuda.is_available() else False,
        )
        # Step 4: Build the Trainer. Guard against eval_dataset being None
        # (load failure) before calling len() on it.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset if (eval_dataset is not None and len(eval_dataset) > 0) else None,
            data_collator=data_collator,
            tokenizer=tokenizer,  # Used by the Trainer for internal operations (e.g., saving)
        )
        # Attach the custom progress callback for UI updates.
        trainer.add_callback(CustomProgressCallback(status_manager))
        status_manager.update_status("πŸš€ Starting training process...", 30)
        print("Training initiated...")
        # Step 5: Run the training loop.
        train_result = trainer.train()
        # Step 6: Record final results.
        training_stats = train_result
        status_manager.set_model_trained(True)
        training_time = train_result.metrics.get('train_runtime', 0)
        minutes = int(training_time // 60)
        seconds = int(training_time % 60)
        final_loss = train_result.training_loss if hasattr(train_result, 'training_loss') else train_result.metrics.get('train_loss', 'N/A')
        # final_loss may be the string 'N/A'; only apply float formatting to numbers.
        loss_str = f"{final_loss:.4f}" if isinstance(final_loss, (int, float)) else str(final_loss)
        success_msg = f"πŸŽ‰ Training completed! Final Loss: {loss_str}, Time: {minutes}m {seconds}s"
        status_manager.update_status(success_msg, 100)
        print(success_msg)
    except Exception as e:
        error_msg = f"❌ Training failed: {str(e)}"
        status_manager.update_status("❌ Training failed", 0, error_msg)
        print(error_msg)


def start_training(batch_size, grad_accum, epochs, lr):
    """Validate preconditions and launch training in a daemon thread."""
    state = status_manager.get_status()
    if not state['model_loaded']:
        return "❌ Please load the model first!"
    if model is None:
        return "❌ Model not available!"
    thread = threading.Thread(
        target=train_model_background,
        args=(batch_size, grad_accum, epochs, lr),
        daemon=True  # Daemon threads exit automatically with the main program
    )
    thread.start()
    return "πŸš€ Started training in background..."
# --- Status Retrieval for UI ---
def get_current_status():
    """Format the StatusManager snapshot as Markdown for the Gradio status display."""
    state = status_manager.get_status()
    status_text = f"""πŸ“Š **Current Status**: {state['status']}
πŸ“ˆ **Progress**: {state['progress']}%
πŸ€– **Model Loaded**: {'βœ…' if state['model_loaded'] else '❌'}
πŸŽ“ **Model Trained**: {'βœ…' if state['model_trained'] else '❌'}"""
    if state['error']:
        status_text += f"\n❌ **Error**: {state['error']}"
    return status_text


# --- Chat Functionality (Inference) ---
def chat_with_model(message, history, temperature=1.0, max_tokens=200):
    """Generate an assistant reply for ``message`` given the chat ``history``.

    Args:
        message: The user's new message (Hausa text).
        history: List of {"role": ..., "content": ...} dicts (Gradio "messages" format).
        temperature: Sampling temperature (higher = more random).
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The updated history including the user message and the model's reply
        (or an error message in the assistant slot).
    """
    global model, tokenizer
    state = status_manager.get_status()
    if not state['model_loaded'] or model is None:
        # Surface the problem in the chat itself rather than failing silently.
        return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Please load the model first!"}]
    # Re-serialize the chat history into the "User:/Assistant:" format used in training.
    conversation = ""
    # Optional: prepend a system message here if training didn't include one.
    # system_message = "Kai likita ne mai hankali da Ζ™warewa a fannin kiwon lafiya. Ka ba da shawarwari masu tushe a kimiyya, masu dacewa da al'adun Nijeriya.\n"
    # conversation += system_message
    for msg_pair in history:
        if msg_pair["role"] == "user":
            conversation += f"User: {msg_pair['content']}\n"
        else:  # "assistant"
            conversation += f"Assistant: {msg_pair['content']}\n"
    # Append the current user message and prompt the assistant turn.
    conversation += f"User: {message}\nAssistant:"
    try:
        inputs = tokenizer(conversation, return_tensors="pt").to(device)
        with torch.no_grad():  # No gradients needed for inference
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,       # Cap on generated tokens
                temperature=temperature,         # Randomness control
                top_p=0.95,                      # Nucleus sampling
                top_k=50,                        # Top-k sampling
                do_sample=True,                  # Enable sampling (vs. greedy)
                repetition_penalty=1.1,          # Discourage repeated tokens
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the prompt to keep only the newly generated text.
        # NOTE(review): this assumes decode() reproduces the prompt verbatim;
        # tokenization round-trips can shift whitespace — verify if replies
        # ever start mid-word.
        assistant_response = full_response[len(conversation):].strip()
        # --- Post-processing for cleaner responses ---
        if assistant_response.startswith("Assistant:"):
            assistant_response = assistant_response[len("Assistant:"):].strip()
        # Truncate if the model starts hallucinating the next turn.
        if "User:" in assistant_response:
            assistant_response = assistant_response.split("User:")[0].strip()
        if "System:" in assistant_response:
            assistant_response = assistant_response.split("System:")[0].strip()
        return history + [{"role": "user", "content": message}, {"role": "assistant", "content": assistant_response}]
    except Exception as e:
        error_msg = f"❌ Error during chat: {str(e)}"
        print(error_msg)
        return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]


# --- Model Saving ---
def save_model(output_path):
    """Save the fine-tuned model and tokenizer to ``output_path``.

    Returns a status string for the UI.
    """
    global model, tokenizer
    state = status_manager.get_status()
    if not state['model_trained']:
        return "❌ Please complete training first before saving!"
    # Guard against a stale trained-flag with no model in memory, which would
    # otherwise raise AttributeError instead of a readable message.
    if model is None or tokenizer is None:
        return "❌ Model not available!"
    try:
        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        return f"βœ… Model saved to {output_path}!"
    except Exception as e:
        return f"❌ Error saving model: {str(e)}"


# --- Sample Queries for UI ---
SAMPLE_QUERIES = [
    "Ina jin ciwon kai da zazzabi tun kwana biyu. Me ya kamata in yi?",
    "Dana yana da gudawa sosai. Ina bukatan taimako.",
    "Yaya ake hana malaria lokacin damina?",
    "Ina da ciwon sukari. Wanne abinci ya dace da ni?",
]

# --- Gradio Interface Definition ---
with gr.Blocks(title="Hausa Health Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("# πŸ₯ Hausa Health Assistant")
    gr.Markdown("Train and test an AI health assistant in Hausa language")
    # Live status panel, refreshed by the timer below.
    status_display = gr.Markdown("πŸ“Š **Current Status**: Ready")
    with gr.Tabs():
        # --- Model Management Tab ---
        with gr.TabItem("πŸ€– Model Management"):
            with gr.Row():
                with gr.Column():
                    load_btn = gr.Button("πŸš€ Load Base Model", variant="primary")
                    prep_btn = gr.Button("βš™οΈ Prepare for Training (LoRA)")
                    gr.Markdown("### Training Parameters")
                    # Defaults tuned for feasible CPU testing.
                    batch_size = gr.Slider(1, 4, 1, step=1, label="Batch Size", info="Per device batch size for training. Start with 1 on CPU to avoid OOM.")
                    grad_accum = gr.Slider(1, 8, 1, step=1, label="Gradient Accumulation", info="Number of updates steps to accumulate gradients for. Start with 1 for debugging.")
                    epochs = gr.Slider(1, 3, 1, step=1, label="Epochs", info="Number of training epochs. Start with 1 for initial tests.")
                    learning_rate = gr.Slider(1e-5, 5e-4, 2e-4, label="Learning Rate", info="Initial learning rate for the optimizer.")
                    train_btn = gr.Button("🎯 Start Training", variant="primary")
                    gr.Markdown("### Save Model")
                    save_path = gr.Textbox(value="hausa-health-assistant-finetuned", label="Save Path", info="Directory to save the fine-tuned model and tokenizer.")
                    save_btn = gr.Button("πŸ’Ύ Save Model")
                with gr.Column():
                    operation_status = gr.Textbox(label="Operation Status Log", lines=3, interactive=False, value="Awaiting operations...")
        # --- Chat Tab ---
        with gr.TabItem("πŸ’¬ Chat"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Sample Queries")
                    sample_btns = []  # (button, query) pairs wired up below
                    for query_text in SAMPLE_QUERIES:
                        btn = gr.Button(f"{query_text[:40]}...", size="sm")
                        sample_btns.append((btn, query_text))
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(label="Health Assistant", height=400, type="messages", layout="bubble")
                    msg = gr.Textbox(label="Your Message (Hausa)", placeholder="Ka rubuta tambayarka...")
                    with gr.Row():
                        temp = gr.Slider(0.1, 2.0, 1.0, label="Temperature", info="Controls the randomness of the model's output (higher = more creative).")
                        tokens = gr.Slider(50, 500, 200, label="Max Tokens", info="Maximum number of new tokens to generate in the response.")
                    with gr.Row():
                        send_btn = gr.Button("πŸ“€ Send", variant="primary")
                        clear_btn = gr.Button("πŸ—‘οΈ Clear Chat")
    # --- Event Handlers for Gradio Components ---
    load_btn.click(start_model_loading, outputs=[operation_status])
    prep_btn.click(prepare_model_for_training, outputs=[operation_status])
    train_btn.click(
        start_training,
        inputs=[batch_size, grad_accum, epochs, learning_rate],
        outputs=[operation_status]
    )
    save_btn.click(save_model, inputs=[save_path], outputs=[operation_status])
    # Chat interaction: send, then clear the input box.
    send_btn.click(
        chat_with_model,
        inputs=[msg, chatbot, temp, tokens],
        outputs=[chatbot]
    ).then(lambda: gr.Textbox(value="", interactive=True), outputs=[msg])
    clear_btn.click(lambda: [], outputs=[chatbot])  # Clear chat history
    # Wire up the sample query buttons; the default argument binds each
    # query at definition time (avoids the late-binding closure pitfall).
    for btn, query in sample_btns:
        btn.click(lambda q=query: gr.Textbox(value=q, interactive=True), inputs=[], outputs=[msg])
    # Refresh the status panel every 2 seconds.
    status_timer = gr.Timer(value=2)
    status_timer.tick(get_current_status, outputs=[status_display])


# --- Main Application Entry Point ---
def main():
    """Parse command-line arguments and launch the Gradio application."""
    parser = argparse.ArgumentParser(description="Hausa Health Assistant Training App")
    parser.add_argument("--share", action="store_true", default=False, help="Create a shareable link for the Gradio app.")
    args = parser.parse_args()
    # Free any cached GPU memory before serving.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Silence the tokenizers fork/parallelism warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    print("Starting Hausa Health Assistant app...")
    # Bind to all interfaces on port 7860; --share creates a public link.
    app.launch(server_name="0.0.0.0", server_port=7860, share=args.share)


if __name__ == "__main__":
    main()