## Overview
Biomedicine has been a major focus of the AI community for many years. With the recent rise of large language models (LLMs), a growing number of applications use these models to support biomedicine and healthcare.

We present **MedConclusion**, a fine-tuned LLM based on Phi3-medium and trained on over 250,000 PubMed articles. MedConclusion takes a research article as input and generates a clear, concise conclusion. This specialized model has a wide range of potential applications, listed below.
### Applications
1. **Clinical Decision Support Systems:** By integrating the model into hospital systems or pharmaceutical research institutes, healthcare providers can access summarized research findings to support evidence-based clinical decisions.
2. **Medical Education and Training:** Students get concise summaries that highlight the key information in each article.
3. **Public Health Policy Development:** Policymakers can use research summaries to make informed decisions about public health strategies.
4. **Research Recommendation Systems:** Generated conclusions help researchers stay up to date with minimal effort.
5. **Biomedical Search Engines:** Search results enriched with auto-generated conclusions give users quick insight into the relevance and key takeaways of each paper.
6. **Academic Tools and Writing Assistance:** Academic platforms can use the model to create quick summaries of papers under review, improving their visibility and accessibility.
## Training Dataset
MedConclusion is fine-tuned on the PubMedQA dataset, specifically the PQA-A and PQA-U training subsets.

For a complete description of the PubMedQA dataset, see the original source: https://github.com/pubmedqa/pubmedqa

The instruction and training datasets used are:
- `pqau_genconc.jsonl` and `pqaa_genconc.jsonl` for training.
- `testset_genconc.jsonl` for validation.
- `pubmedqa_testset.csv` for inference.

The datasets are available under `med-rcq/med-rcq-dataset`: https://huggingface.co/datasets/med-rcq/med-rcq-dataset/tree/main
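As a rough illustration of how the files listed above might be fetched and read, here is a minimal sketch. It assumes that `huggingface_hub` is installed and that the files sit at the root of the dataset repository; neither is confirmed by this README. The helper names `download_split` and `read_jsonl` are hypothetical, and because the record schema is not documented here, the code only parses each line as JSON without assuming any field names.

```python
import json


def read_jsonl(path):
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    records = []
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def download_split(filename, repo_id="med-rcq/med-rcq-dataset"):
    """Download one file from the dataset repo and return its local path.

    Assumption: `filename` is located at the repository root.
    """
    # Imported lazily so read_jsonl stays usable without huggingface_hub.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")


# Example usage (requires network access):
# path = download_split("testset_genconc.jsonl")
# records = read_jsonl(path)
# print(len(records), "records;", "keys:", sorted(records[0].keys()))
```

Inspecting the keys of the first record, as in the commented usage lines, is a quick way to discover the actual field names before building anything on top of them.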
## Training and Inference Parameters
| Training Parameter               | MedConclusion  |
|----------------------------------|----------------|
| Learning rate                    | 2e-04          |
| Seed                             | 42             |
| Scheduler                        | cosine         |
| Warmup ratio                     | 0.05           |
| Optimizer                        | AdamW          |
| Gradient accumulation steps      | 4              |
| Train batch size per device      | 8              |
| Effective batch size             | 192            |
| Evaluation batch size            | 4              |
| Cut-off length                   | 1024           |
| Number of GPUs                   | 6              |
| LoRA rank                        | 32             |
| LoRA alpha                       | 32             |
| LoRA dropout                     | 0.05           |
| LoRA target                      | All            |
| Number of epochs                 | 1              |
| Max grad norm                    | 1.0            |
| **Training time**                | **11.5 hours** |

| Inference Parameter  | MedConclusion |
|----------------------|---------------|
| Temperature          | 0.01          |
| Max new tokens       | 250           |
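Two values in the tables above can be checked with a few lines of arithmetic. The sketch below assumes the usual data-parallel relation (per-device batch size × gradient accumulation steps × number of GPUs), which matches the reported effective batch size. It also shows why a temperature of 0.01 makes sampling behave almost like greedy decoding. The function names and the example logits are illustrative only, and the softmax helper is not the sampling code used by `transformers`.

```python
import math


def effective_batch_size(per_device, grad_accum_steps, num_gpus):
    """Samples contributing to one optimizer step under data parallelism."""
    return per_device * grad_accum_steps * num_gpus


def softmax_with_temperature(logits, temperature):
    """Softmax over logits divided by the temperature (numerically stable)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Values from the training table: 8 per device, 4 accumulation steps, 6 GPUs.
print(effective_batch_size(8, 4, 6))  # 192

# Made-up logits for three candidate tokens, for illustration only.
example_logits = [2.0, 1.5, 0.5]
probs_default = softmax_with_temperature(example_logits, 1.0)
probs_low = softmax_with_temperature(example_logits, 0.01)
print(probs_default)  # probability mass spread over all three tokens
print(probs_low)      # almost all mass on the highest-scoring token
```

With the temperature this low, sampling picks the top-scoring token in practice, so outputs are close to deterministic even though sampling is enabled.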
## Environment Setup
- OS: Ubuntu 22.04.3
- GPU: A40 or RTX A6000, CUDA 12.4
- To set up the environment, run the following commands:

      curl -O https://repo.anaconda.com/archive/Anaconda3-2024.02-1-Linux-x86_64.sh
      /bin/bash Anaconda3-2024.02-1-Linux-x86_64.sh -b -p /opt/conda
      source ~/.bashrc
      export PATH=/opt/conda/bin:$PATH
      source /opt/conda/bin/activate
      conda create -n medrcq_env python=3.11.7 -y
      conda activate medrcq_env
      pip install torch==2.5.1 transformers==4.48.0 pandas==2.1.4
      pip install flash-attn==2.7.3

- The inference code is detailed below.
```python
from transformers import pipeline, set_seed
import torch
import pandas as pd
import argparse

MODEL_PATH = "med-rcq/MedConclusion"
set_seed(42)

# Prompt template; _TITLE_ and _CONTEXT_ are replaced for each article.
SYSTEM_PROMPT = '''You are a helpful medical assistant. Write a conclusion for the following article:\n
Title: _TITLE_
_CONTEXT_
Conclusion:'''

pipe = pipeline(
    "text-generation",
    model=MODEL_PATH,
    model_kwargs={"torch_dtype": torch.bfloat16},
    trust_remote_code=True,
    do_sample=True,
    temperature=0.01,
    device="cuda",  # replace with "mps" to run on a Mac device
)


def generate_ai_conclusion(prompt):
    """
    Generates a medical conclusion based on a given prompt.

    Args:
        prompt (str): The input prompt for the model.
    Returns:
        str: The generated conclusion.
    """
    messages = [{"role": "user", "content": prompt}]
    outputs = pipe(messages, max_new_tokens=250)
    # The pipeline returns the whole chat; the last message is the model's reply.
    assistant_response = outputs[0]["generated_text"][-1]["content"].strip()
    return assistant_response


def process_csv(input_file, output_file):
    """
    Processes a CSV file by generating a medical conclusion for each row.
    Each row represents a PubMed article.

    Args:
        input_file (str): Path to the input CSV file.
        output_file (str): Path to save the processed CSV file with the generated conclusions.
    """
    # Read the input CSV file into a DataFrame
    try:
        df = pd.read_csv(input_file)
    except FileNotFoundError:
        print(f"Error: File {input_file} not found.")
        return
    except pd.errors.EmptyDataError:
        print("Error: Input file is empty.")
        return

    cols_order = ['ID', 'Question', 'Context_with_label', 'LONG_ANSWER', 'final_decision']
    df = df[cols_order]
    # Create the output column up front so row-wise assignment keeps an object dtype
    df['Medconc_Generated_conclusion'] = None

    # Loop over each row in the DataFrame
    for index, row in df.iterrows():
        article_id = row['ID']
        # The PubMed article title is phrased as a question
        title = str(row['Question'])
        # The article body without the title or conclusion; it keeps section
        # labels such as "Background" and "Methods"
        context_string = str(row['Context_with_label'])
        # The conclusion section of the PubMed article (reference text, not used in the prompt)
        long_answer = row['LONG_ANSWER']
        # Either yes, no, or maybe, as agreed by both annotators (not used in the prompt)
        final_decision = row['final_decision']

        print("\n########## INDEX:" + str(index) + " ## QID:" + str(article_id) + " ##########\n")

        # Fill in the prompt template
        formatted_prompt = SYSTEM_PROMPT.replace("_TITLE_", title)
        formatted_prompt = formatted_prompt.replace("_CONTEXT_", context_string)

        # Call the model to generate the conclusion from the constructed prompt
        generated_conclusion = generate_ai_conclusion(formatted_prompt)
        print(generated_conclusion)

        # Store the generated conclusion
        df.at[index, 'Medconc_Generated_conclusion'] = generated_conclusion

    # Save the output to a CSV file
    df.to_csv(output_file, index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process a CSV file to generate conclusions.")
    # Input file, e.g. pubmedqa_testset.csv
    parser.add_argument("input_file", help="Path to the input CSV file")
    # Name of the output file
    parser.add_argument("output_file", help="Path to save the processed CSV file")
    args = parser.parse_args()

    # Process the file
    process_csv(args.input_file, args.output_file)