| import gradio as gr |
| from faster_whisper import WhisperModel |
| from pydantic import BaseModel, Field, AliasChoices, field_validator, ValidationError |
| from typing import List |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| import csv |
| import json |
| import tempfile |
| import torch |
| import os |
|
|
| |
# NOTE(review): presumably intended to switch off Gradio's request queue.
# Nothing in this file reads the variable - confirm that the hosting
# environment actually consumes it.
os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue"

# Identifiers of the three models the app loads at import time.
numind_checkpoint = "numind/NuExtract-tiny"
llama_checkpoint = "Atereoyin/Llama3_finetuned_for_medical_entity_extraction"
whisper_checkpoint = "base"

# Passed to both causal language models below so they are loaded in 8-bit.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
|
| |
# Speech-to-text model used by transcribe_audio.
whisper_model = WhisperModel(whisper_checkpoint, device="cuda")

# Model and tokenizer used for the Name / Age / Gender extraction.
# NOTE(review): trust_remote_code=True executes code shipped with the
# checkpoint - confirm both checkpoints are trusted and pinned.
numind_model = AutoModelForCausalLM.from_pretrained(
    numind_checkpoint,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
numind_tokenizer = AutoTokenizer.from_pretrained(numind_checkpoint)

# Model and tokenizer used for the medical entity extraction.
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_checkpoint,
    quantization_config=quantization_config,
    trust_remote_code=True,
)
llama_tokenizer = AutoTokenizer.from_pretrained(llama_checkpoint)
|
|
| |
def transcribe_audio(audio_file_path):
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        audio_file_path: Path of the audio file handed to
            ``whisper_model.transcribe``.

    Returns:
        The text of all returned segments concatenated without a separator.
        If anything raises while transcribing, the exception message is
        returned instead, so callers cannot tell a failure from a transcript
        by type alone.
    """
    try:
        # The second element of the returned pair is not used here.
        chunks, _ = whisper_model.transcribe(audio_file_path, beam_size=5)
        return "".join(chunk.text for chunk in chunks)
    except Exception as err:
        return str(err)
|
|
| |
def predict_NuExtract(model, tokenizer, text, schema, example=("", "", "")):
    """Run a NuExtract-style structured extraction over ``text``.

    The prompt is made of the JSON template, any non-empty examples, the
    text itself and a trailing ``<|output|>`` marker; the part of the decoded
    generation that follows that marker is returned.

    Args:
        model: Causal language model exposing ``generate``. Its ``device``
            attribute, when present, decides where the inputs are moved;
            otherwise ``"cuda"`` is used as before.
        tokenizer: Tokenizer matching ``model``.
        text: Text to extract information from.
        schema: JSON string describing the fields to fill in.
        example: Iterable of JSON strings with filled-in examples; empty
            entries are skipped.

    Returns:
        The raw string found between ``<|output|>`` and ``<|end-output|>``,
        or ``""`` when the decoded output contains no ``<|output|>`` marker
        (for instance because the prompt was truncated).

    Raises:
        json.JSONDecodeError: If ``schema`` or an example is not valid JSON.
    """
    parts = [
        "<|input|>\n### Template:\n",
        json.dumps(json.loads(schema), indent=4),
        "\n",
    ]
    for sample in example:
        if sample:
            # Re-serialise so examples are formatted the same way as the template.
            parts.append("### Example:\n" + json.dumps(json.loads(sample), indent=4) + "\n")
    parts.append("### Text:\n" + text + "\n<|output|>\n")
    input_llm = "".join(parts)

    # Follow the model instead of assuming a GPU is where it lives.
    device = getattr(model, "device", "cuda")
    input_ids = tokenizer(
        input_llm, return_tensors="pt", truncation=True, max_length=4000
    ).to(device)

    output = tokenizer.decode(model.generate(**input_ids)[0], skip_special_tokens=True)

    _, marker, completion = output.partition("<|output|>")
    if not marker:
        # Nothing to extract; an empty string fails JSON parsing in the caller
        # instead of raising an IndexError here.
        return ""
    # Keep only the text between the first marker and the next delimiter.
    return completion.split("<|output|>")[0].split("<|end-output|>")[0]
|
|
|
|
| |
def prompt_format(text):
    """Build the instruction / input / response prompt for entity extraction.

    Args:
        text: The medical conversation to place in the ``### Input:`` section.

    Returns:
        The full prompt string. It ends right after ``### Response:`` and a
        newline, leaving the response itself empty for the model to complete.
    """
    preamble = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request."
    )
    task = "\n".join(
        [
            "Extract the following entities from the medical conversation:",
            "* **Symptoms:** List all the symptoms the patient mentions.",
            "* **Diagnosis:** List the doctor's diagnosis or potential diagnoses.",
            "* **Medical History:** Summarize the patient's relevant medical history.",
            "* **Action Plan:** List the recommended actions or treatment plan.",
            "",
            "Provide the result in the following JSON format:",
            "{",
            '"Symptoms": [...],',
            '"Diagnosis": [...],',
            '"Medical history": [...],',
            '"Action plan": [...]',
            "}",
        ]
    )
    # The braces inside ``task`` and ``text`` are inserted as values, so they
    # are never interpreted as format fields.
    return f"{preamble}\n\n### Instruction:\n{task}\n\n### Input:\n{text}\n\n### Response:\n"
|
|
|
|
| |
class _MedicalRecord(BaseModel):
    """Schema of the medical entities extracted from a conversation.

    Defined once at module level so the model is not rebuilt on every call
    to ``validate_medical_record``.
    """

    Symptoms: List[str] = Field(default_factory=list)
    Diagnosis: List[str] = Field(default_factory=list)
    # Accepted under either of the two spellings listed in the alias choices.
    Medical_history: List[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices('Medical history', 'History of Patient'),
    )
    Action_plan: List[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices('Action plan', 'Plan of Action'),
    )

    @field_validator('*', mode='before')
    @classmethod
    def ensure_list(cls, v):
        """Turn a comma-separated string into a list of trimmed, non-empty items."""
        if isinstance(v, str):
            return [item.strip() for item in v.split(',') if item.strip()]
        return v


def validate_medical_record(response):
    """Normalise a raw entity mapping into the medical record schema.

    Args:
        response: Mapping of field name (or accepted alias) to either a list
            of strings or a comma-separated string.

    Returns:
        A dict with the keys ``Symptoms``, ``Diagnosis``, ``Medical_history``
        and ``Action_plan``, each holding a list of strings. If validation
        fails, ``response`` is returned unchanged so the caller still gets
        the raw data.
    """
    try:
        # model_dump() is the pydantic v2 replacement for the deprecated dict().
        return _MedicalRecord(**response).model_dump()
    except ValidationError as e:
        # Best effort: report the problem and hand back the unvalidated input.
        print(f"Medical record validation failed: {str(e)}")
        return response
|
|
|
|
|
|
| |
def _parse_llama_response(response):
    """Convert the model's answer text into a ``{field: value}`` dict."""
    # The prompt asks for JSON, so try the outermost {...} span first.
    start, end = response.find("{"), response.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(response[start:end + 1])
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed

    # Otherwise fall back to plain "Key: value" lines.
    fields = {}
    for line in response.splitlines():
        key, separator, value = line.partition(': ')
        if separator:
            fields[key.strip()] = value.strip()
    return fields


def predict_Llama(model, tokenizer, text):
    """Extract medical entities from ``text`` with the fine-tuned Llama model.

    Args:
        model: Causal language model exposing ``generate``. Its ``device``
            attribute, when present, decides where the inputs are moved;
            otherwise ``"cuda"`` is used as before.
        tokenizer: Tokenizer matching ``model``.
        text: The medical conversation to analyse.

    Returns:
        The result of ``validate_medical_record`` applied to the parsed
        answer, or ``{}`` if generation or parsing raised.
    """
    device = getattr(model, "device", "cuda")
    inputs = tokenizer(prompt_format(text), return_tensors="pt", truncation=True).to(device)

    try:
        # NOTE(review): temperature normally only matters when sampling is
        # enabled - confirm whether do_sample=True was intended here.
        outputs = model.generate(**inputs, max_new_tokens=128, temperature=0.2, use_cache=True)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The decoded text echoes the prompt; keep only what follows the
        # response header.
        response = decoded.split("### Response:", 1)[-1].strip()
        return validate_medical_record(_parse_llama_response(response))
    except Exception as e:
        print(f"Error during Llama prediction: {str(e)}")
        return {}
|
|
|
|
| |
def _join_entities(value):
    """Render one extracted field (list, string or missing) as display text."""
    if value is None:
        return ""
    if isinstance(value, str):
        # Joining a bare string would put ", " between its characters.
        return value
    return ", ".join(str(item) for item in value)


def process_audio(audio):
    """Turn a recording into the seven fields of the diagnostic form.

    The audio is transcribed, personal details are extracted with NuExtract
    and medical entities with the Llama model.

    Args:
        audio: Path of an audio file, or the raw bytes of one.

    Returns:
        A 7-tuple ``(name, age, gender, symptoms, diagnosis,
        medical_history, action_plan)``; the last four are comma-separated
        strings.

    Raises:
        gr.Error: If no audio was supplied or the NuExtract answer is not a
            JSON object. A bare string cannot be returned here because the
            interface expects seven output values.
    """
    if audio is None:
        raise gr.Error("No audio was provided.")

    temp_path = None
    if isinstance(audio, str):
        # Already a file on disk: transcribe it in place instead of copying it.
        audio_path = audio
    else:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio)
        temp_path = audio_path = temp_audio.name

    try:
        transcription = transcribe_audio(audio_path)
    finally:
        # delete=False leaves the file behind, so remove it ourselves.
        if temp_path is not None:
            os.remove(temp_path)

    person_schema = """{"Name": "","Age": "","Gender": ""}"""
    person_entities_raw = predict_NuExtract(numind_model, numind_tokenizer, transcription, person_schema)

    try:
        person_entities = json.loads(person_entities_raw)
    except json.JSONDecodeError as e:
        raise gr.Error(f"Error in NuExtract response: {str(e)}") from e
    if not isinstance(person_entities, dict):
        raise gr.Error("Error in NuExtract response: expected a JSON object")

    medical_entities = predict_Llama(llama_model, llama_tokenizer, transcription)

    return (
        person_entities.get("Name", ""),
        person_entities.get("Age", ""),
        person_entities.get("Gender", ""),
        _join_entities(medical_entities.get("Symptoms", [])),
        _join_entities(medical_entities.get("Diagnosis", [])),
        _join_entities(medical_entities.get("Medical_history", [])),
        _join_entities(medical_entities.get("Action_plan", [])),
    )
|
|
|
|
|
|
| |
def download_csv(name, age, gender, symptoms, diagnosis, medical_history, action_plan):
    """Write the form fields to a temporary CSV file.

    Args:
        name: Patient name.
        age: Patient age.
        gender: Patient gender.
        symptoms: Symptoms text.
        diagnosis: Diagnosis text.
        medical_history: Medical history text.
        action_plan: Plan of action text.

    Returns:
        Path of the CSV file, which holds a header row followed by one data
        row. The file is left on disk for the caller to serve.
    """
    header = ["Name", "Age", "Gender", "Symptoms", "Diagnosis", "Medical History", "Plan of Action"]
    record = [name, age, gender, symptoms, diagnosis, medical_history, action_plan]

    # Write through the temporary file's own handle so it is closed exactly
    # once; UTF-8 keeps non-ASCII text independent of the system locale.
    with tempfile.NamedTemporaryFile(
        mode="w", newline="", encoding="utf-8", suffix=".csv", delete=False
    ) as csv_file:
        csv.writer(csv_file).writerows([header, record])

    return csv_file.name
|
|
|
|
|
|
| |
# One audio input feeding process_audio, whose seven return values fill the
# seven text boxes below in the same order.
demo = gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(type="filepath")],
    outputs=[
        gr.Textbox(label=label)
        for label in (
            "Name",
            "Age",
            "Gender",
            "Symptoms",
            "Diagnosis",
            "Medical History",
            "Plan of Action",
        )
    ],
    title="Medical Diagnostic Form Assistant",
    description="Upload an audio file or record audio to generate a medical diagnostic form.",
)

# Extend the generated interface with a button that exports the current
# contents of the output boxes as a CSV file.
with demo:
    download_button = gr.Button("Download CSV")
    download_button.click(
        fn=lambda name, age, gender, symptoms, diagnosis, medical_history, action_plan: download_csv(
            name, age, gender, symptoms, diagnosis, medical_history, action_plan
        ),
        inputs=demo.output_components,
        outputs=gr.File(label="Download CSV"),
    )

demo.launch()
|
|
|
|