# MedChat — Gradio chat app comparing three medical LLM backends (Hugging Face Space).
import os
from textwrap import wrap, fill

import requests
import torch
import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
from peft import PeftModel, PeftConfig
## using Falcon 7b Instruct
Falcon_API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
# HF API token from the environment; may be None if the env var is unset.
hf_token = os.getenv("HUGGINGFACE_TOKEN")
# BUG FIX: the original used a plain string "Bearer {hf_token}", so the literal
# placeholder text (not the token) was sent in the Authorization header and
# every authenticated request would fail. Use an f-string to interpolate.
HEADERS = {"Authorization": f"Bearer {hf_token}"}
def falcon_query(payload):
    """POST *payload* to the Falcon-7B-Instruct HF inference endpoint and return the parsed JSON reply."""
    return requests.post(Falcon_API_URL, headers=HEADERS, json=payload).json()
def falcon_inference(input_text):
    """Run remote Falcon-7B-Instruct inference on *input_text* via the HF Inference API."""
    return falcon_query({"inputs": input_text})
## using Mistral
Mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"

def mistral_query(payload):
    """POST *payload* to the Mistral-7B HF inference endpoint and return the parsed JSON reply."""
    return requests.post(Mistral_API_URL, headers=HEADERS, json=payload).json()

def mistral_inference(input_text):
    """Run remote Mistral-7B inference on *input_text* via the HF Inference API."""
    return mistral_query({"inputs": input_text})
# Functions to Wrap the Prompt Correctly
def wrap_text(text, width=90):
    """Word-wrap each line of *text* to *width* columns, preserving existing newlines."""
    return "\n".join(fill(segment, width=width) for segment in text.split("\n"))
class ChatbotInterface():
    """Base class bundling the Gradio widgets (chat window, textbox, submit and
    clear buttons) shared by every bot tab; subclasses implement ``respond``.
    """

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        """Build the widgets and wire the submit button to ``respond``.

        name: display name of the bot (used by subclasses in their prompts).
        system_prompt: instruction text prepended by subclasses.
        """
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        # Accumulated [user, bot] message pairs shown in the Chatbot widget.
        self.chat_history = []
        with gr.Row() as row:
            row.justify = "end"
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)
        clear = gr.ClearButton([self.msg, self.chatbot])
        # FIX: removed a dead local `chat_history = []` that shadowed the
        # instance attribute and was never read, plus a commented-out
        # `self.msg.change(...)` line that was dead code.
        self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])

    def respond(self, msg, chatbot):
        """Produce a reply for *msg*; must be overridden by subclasses."""
        raise NotImplementedError
class GaiaMinimed(ChatbotInterface):
    """Chat tab backed by the locally loaded GaiaMiniMed PEFT model.

    Relies on the module-level ``tokenizer`` and ``peft_model`` globals created
    in the ``__main__`` section.
    """

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, history):
        """Generate a reply locally and return ("", chat_history).

        The empty string clears the textbox; the history list refreshes the
        gr.Chatbot component.
        """
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False,
        )
        output_ids = peft_model.generate(
            input_ids=input_ids,
            max_length=500,
            use_cache=False,
            early_stopping=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True,
        )
        # BUG FIX: decode only the newly generated tokens — decoding the whole
        # sequence echoed the prompt back into the displayed answer.
        new_tokens = output_ids[0][input_ids.shape[-1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        # BUG FIX: store the user's message, not the internal formatted prompt,
        # so the system prompt is not shown in the chat window (and to stay
        # consistent with the other bots' histories).
        self.chat_history.append([msg, response_text])
        return "", self.chat_history
class FalconBot(ChatbotInterface):
    """Chat tab backed by the hosted Falcon-7B-Instruct inference endpoint."""

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, chatbot):
        """Query the remote Falcon endpoint and return ("", chat_history).

        BUG FIX: the original returned the raw generated string as the second
        value, but the wired output component is a gr.Chatbot, which expects a
        list of [user, bot] pairs. Return the accumulated history instead,
        consistent with GaiaMinimed.respond.
        """
        falcon_response = falcon_inference(msg)
        falcon_output = falcon_response[0]["generated_text"]
        self.chat_history.append([msg, falcon_output])
        return "", self.chat_history
class MistralBot(ChatbotInterface):
    """Chat tab backed by the hosted Mistral-7B inference endpoint."""

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, chatbot):
        """Query the remote Mistral endpoint and return ("", chat_history).

        BUG FIX: the original returned the raw generated string as the second
        value, but the wired output component is a gr.Chatbot, which expects a
        list of [user, bot] pairs. Return the accumulated history instead,
        consistent with GaiaMinimed.respond.
        """
        mistral_response = mistral_inference(msg)
        mistral_output = mistral_response[0]["generated_text"]
        self.chat_history.append([msg, mistral_output])
        return "", self.chat_history
if __name__ == "__main__":
    # Define the device.
    # NOTE(review): `device` is computed but the model is never moved to it —
    # confirm whether `.to(device)` was intended after loading.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Use the base model's ID and the fine-tuned adapter repository.
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"
    # Instantiate the tokenizer (left padding for decoder-only generation).
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
    # Specify the configuration class for the model.
    model_config = AutoConfig.from_pretrained(base_model_id)
    # Load the fine-tuned weights, then attach the PEFT adapter on top.
    peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
    peft_model = PeftModel.from_pretrained(peft_model, model_directory)

    with gr.Blocks() as demo:
        with gr.Row() as intro:
            gr.Markdown(
                """
                ## MedChat
                Welcome to MedChat, a medical assistant chatbot! You can currently chat with three chatbots that are trained on the same medical dataset.
                If you want to compare the output of each model, click the submit to all button and see the magic happen!
                """
            )
        with gr.Row() as row:
            with gr.Column() as col1:
                with gr.Tab("GaiaMinimed") as gaia:
                    gaia_bot = GaiaMinimed("GaiaMinimed")
            with gr.Column() as col2:
                with gr.Tab("MistralMed") as mistral:
                    mistral_bot = MistralBot("MistralMed")
                with gr.Tab("Falcon-7B") as falcon7b:
                    falcon_bot = FalconBot("Falcon-7B")
        # Mirror text typed in one bot's textbox into the other two, so a
        # single prompt can be compared across all models.
        # FIX: dropped the no-op `s[::1]` slices (identity copies of `s`).
        gaia_bot.msg.input(fn=lambda s: (s, s), inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
        mistral_bot.msg.input(fn=lambda s: (s, s), inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
        falcon_bot.msg.input(fn=lambda s: (s, s), inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])
    demo.launch()