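"""Hugging Face Space: multi-LLM diagnosis comparator (Team #9).

Collects API keys for OpenAI GPT-3.5-Turbo, Meta Llama 3 (via the Hugging Face
Inference API), and Google Gemini 1.5 Flash, sends the same symptom prompt to
all three models, and reports each answer plus a majority-vote diagnosis with
a simple agreement-based confidence score.
"""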
import re
from collections import Counter

import google.generativeai as genai
import gradio as gr
import openai
import requests
from huggingface_hub import InferenceClient

def api_check_msg(api_key, selected_model):
    res = validate_api_key(api_key, selected_model)
    return res["message"]

def validate_api_key(api_key, selected_model):
    # Validate the key against whichever provider the selected model belongs to.
    if "GPT" in selected_model:
        # OpenAI: listing models succeeds only with a valid key.
        url = "https://api.openai.com/v1/models"
        headers = {
            "Authorization": f"Bearer {api_key}"
        }
        try:
            response = requests.get(url, headers=headers)
            if response.status_code == 200:
                return {"is_valid": True, "message": '<p style="color: green;">GPT API Key is valid!</p>'}
            else:
                return {"is_valid": False, "message": f'<p style="color: red;">Invalid OpenAI API Key. Status code: {response.status_code}</p>'}
        except requests.exceptions.RequestException as e:
            return {"is_valid": False, "message": f'<p style="color: red;">Invalid OpenAI API Key. Error: {e}</p>'}
    elif "Llama" in selected_model:
        # Hugging Face: the whoami endpoint succeeds only with a valid token.
        url = "https://huggingface.co/api/whoami-v2"
        headers = {
            "Authorization": f"Bearer {api_key}"
        }
        try:
            response = requests.get(url, headers=headers)
            if response.status_code == 200:
                return {"is_valid": True, "message": '<p style="color: green;">Llama API Key is valid!</p>'}
            else:
                return {"is_valid": False, "message": f'<p style="color: red;">Invalid Hugging Face API Key. Status code: {response.status_code}</p>'}
        except requests.exceptions.RequestException as e:
            return {"is_valid": False, "message": f'<p style="color: red;">Invalid Hugging Face API Key. Error: {e}</p>'}
    elif "Gemini" in selected_model:
        # Google: a small test generation succeeds only with a valid key.
        try:
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel("gemini-1.5-flash")
            model.generate_content("Help me diagnose the patient.")
            return {"is_valid": True, "message": '<p style="color: green;">Gemini API Key is valid!</p>'}
        except Exception as e:
            return {"is_valid": False, "message": f'<p style="color: red;">Invalid Google API Key. Error: {e}</p>'}

def generate_text_chatgpt(key, prompt, temperature, top_p):
    openai.api_key = key
    response = openai.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=[{"role": "system", "content": "You are a talented diagnostician who is diagnosing a patient based on the symptoms they provided."},
                  {"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=50,
        top_p=top_p,
        frequency_penalty=0,
    )
    return response.choices[0].message.content
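
# Note: setting `openai.api_key` and calling `openai.chat.completions.create`
# relies on the module-level default client kept for openai>=1.0. An
# equivalent, more explicit sketch would be:
#
#   from openai import OpenAI
#   client = OpenAI(api_key=key)
#   client.chat.completions.create(model="gpt-3.5-turbo-1106", messages=...)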

def generate_text_gemini(key, prompt, temperature, top_p):
    genai.configure(api_key=key)
    generation_config = genai.GenerationConfig(
        # Loose output cap derived from the prompt's character count; characters
        # are not tokens, so this is a generous heuristic, not an exact budget.
        max_output_tokens=len(prompt) + 50,
        temperature=temperature,
        top_p=top_p,
    )
    model = genai.GenerativeModel("gemini-1.5-flash", generation_config=generation_config)
    response = model.generate_content(prompt)
    return response.text

def generate_text_llama(key, prompt, temperature, top_p):
    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    client = InferenceClient(api_key=key)
    messages = [{"role": "system", "content": "You are a talented diagnostician who is diagnosing a patient."},
                {"role": "user", "content": prompt}]
    completion = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=len(prompt) + 50,
        temperature=temperature,
        top_p=top_p,
    )
    response = completion.choices[0].message.content
    # Safeguard against prompt echo: trim the prompt only when it is actually
    # a prefix of the response (a length check alone would mangle long replies).
    if response.startswith(prompt):
        return response[len(prompt):]
    return response

def sanitize_outputs(outputs):
    sanitized_results = []
    diagnoses = ["Psoriasis", "Arthritis", "Bronchial Asthma", "Cervical spondylosis"]
    for output in outputs:
        output = output.replace("\n", " ")
        output = re.sub(r"(Diagnose:|Answer:)", "", output, flags=re.IGNORECASE).strip()
        found_diagnoses = [disease for disease in diagnoses if disease in output]
        if found_diagnoses:
            sanitized_results.append(found_diagnoses[0])
        else:
            sanitized_results.append("Unknown")  # Handle case where no valid diagnosis is found
    return sanitized_results
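
# Illustrative behavior (hypothetical model replies):
#
#   sanitize_outputs(["Diagnose: Psoriasis\n", "It appears to be Arthritis.", "I am not sure."])
#   -> ["Psoriasis", "Arthritis", "Unknown"]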

def diagnose(gpt_key, llama_key, gemini_key, top_p, temperature, symptoms):
    if not symptoms:
        # The click handler feeds five output components, so return five values
        # here rather than the status-dict shape used by validate_api_key.
        error = "Please add the symptoms data to start the ranking process."
        return error, error, error, "Unknown", 0
    prompt = (
        "Given the next set of symptoms, classify the diagnosis as one of the following: "
        "Psoriasis, Arthritis, Bronchial Asthma, Cervical spondylosis. "
        "Please only output the classified diagnosis and nothing after that. "
        "Choose only one among the words Psoriasis, Arthritis, Bronchial Asthma or Cervical spondylosis. "
        "Do not list the symptoms again in the response. Do not add any additional text. "
        "Do not attempt to explain your answer. "
    )
    prompt += symptoms
    prompt += " Your Diagnosis: []"
    gpt_message = generate_text_chatgpt(gpt_key, prompt, temperature, top_p)
    llama_message = generate_text_llama(llama_key, prompt, temperature, top_p)
    gemini_message = generate_text_gemini(gemini_key, prompt, temperature, top_p)
    outputs = sanitize_outputs([gpt_message, llama_message, gemini_message])
    # Majority vote: the most common sanitized diagnosis wins, and confidence
    # is the share of models that agree with it.
    output_counts = Counter(outputs)
    majority_output, majority_count = output_counts.most_common(1)[0]
    confidence = int((majority_count / len(outputs)) * 100)
    return gpt_message, llama_message, gemini_message, majority_output, confidence

def update_model_components(selected_model):
    model_map = {
        "GPT-3.5-Turbo": "GPT",
        "Llama-3": "Llama",
        "Gemini-1.5": "Gemini"
    }
    link_map = {
        "GPT-3.5-Turbo": "https://platform.openai.com/account/api-keys",
        "Llama-3": "https://hf.co/settings/tokens",
        "Gemini-1.5": "https://aistudio.google.com/apikey"
    }
    textbox_label = f"Please input the API key for your {model_map[selected_model]} model"
    button_value = f"Don't have an API key? Get one for the {model_map[selected_model]} model here."
    button_link = link_map[selected_model]
    return gr.update(label=textbox_label), gr.update(value=button_value, link=button_link)
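
# Note: update_model_components is written for a single-model selector (e.g. a
# gr.Dropdown) driving one key textbox and one link button; the layout below
# shows all three key fields at once, so this helper is not wired up there.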

def toggle_button(symptoms_text, gpt_key, llama_key, gemini_key):
    # Enable the Diagnose button only when symptoms are present and all three
    # keys validate. Check the "is_valid" flag explicitly: the returned dict
    # is always truthy, so testing the dict itself would always pass.
    if symptoms_text.strip() and validate_api_key(gpt_key, "GPT")["is_valid"] and \
            validate_api_key(llama_key, "Llama")["is_valid"] and \
            validate_api_key(gemini_key, "Gemini")["is_valid"]:
        return gr.update(interactive=True)
    return gr.update(interactive=False)

with gr.Blocks() as ui:
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            gpt_key = gr.Textbox(label="Please input your GPT key", type="password")
            llama_key = gr.Textbox(label="Please input your Llama key", type="password")
            gemini_key = gr.Textbox(label="Please input your Gemini key", type="password")
            # One validation status line per provider key.
            gpt_status = gr.HTML(label="Validation Status")
            gpt_key.input(fn=api_check_msg, inputs=[gpt_key, gr.Textbox(value="GPT", visible=False)], outputs=gpt_status)
            llama_status = gr.HTML(label="Validation Status")
            llama_key.input(fn=api_check_msg, inputs=[llama_key, gr.Textbox(value="Llama", visible=False)], outputs=llama_status)
            gemini_status = gr.HTML(label="Validation Status")
            gemini_key.input(fn=api_check_msg, inputs=[gemini_key, gr.Textbox(value="Gemini", visible=False)], outputs=gemini_status)
            gr.Markdown("### Don't have an LLM key? Get one through the links below.")
            gr.Button(value="OpenAI Key", link="https://platform.openai.com/account/api-keys")
            gr.Button(value="Meta Llama Key", link="https://hf.co/settings/tokens")
            gr.Button(value="Gemini Key", link="https://aistudio.google.com/apikey")
            gr.ClearButton([gpt_key, llama_key, gemini_key], variant="primary")
        with gr.Column(scale=2, min_width=600):
            gr.Markdown("### Hello, welcome to the GUI by Team #9. This is the ranking API.")
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Temperature", info="Set the temperature")
            top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p value", info="Set the nucleus sampling parameter")
            symptoms = gr.Textbox(label="Add the symptom data in the input to receive a diagnosis")
            llm_btn = gr.Button(value="Diagnose Disease", variant="primary", elem_id="diagnose", interactive=False)
            symptoms.input(toggle_button, inputs=[symptoms, gpt_key, llama_key, gemini_key], outputs=llm_btn)
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=150):
            majority_output = gr.Textbox(label="Majority Output", interactive=False, placeholder="Majority Output")
        with gr.Column(scale=1, min_width=150):
            confidence = gr.Textbox(label="Confidence Score (%)", interactive=False, placeholder="Confidence Score")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=66):
            gpt_message = gr.Textbox(label="GPT Output", interactive=False, placeholder="GPT Output")
        with gr.Column(scale=1, min_width=66):
            llama_message = gr.Textbox(label="LLaMA Output", interactive=False, placeholder="LLaMA Output")
        with gr.Column(scale=1, min_width=66):
            gemini_message = gr.Textbox(label="Gemini Output", interactive=False, placeholder="Gemini Output")
    llm_btn.click(fn=diagnose, inputs=[gpt_key, llama_key, gemini_key, top_p, temperature, symptoms],
                  outputs=[gpt_message, llama_message, gemini_message, majority_output, confidence], api_name="LLM_Comparator")

ui.launch(share=True)