Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import shap | |
| import numpy as np | |
| import scipy as sp | |
| import torch | |
| import tensorflow as tf | |
| import transformers | |
| from transformers import pipeline | |
| from transformers import RobertaTokenizer, RobertaModel | |
| from transformers import AutoModelForSequenceClassification | |
| from transformers import TFAutoModelForSequenceClassification # Although imported, this is not used in the provided code | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
| import matplotlib.pyplot as plt | |
| import sys | |
| import csv | |
| csv.field_size_limit(sys.maxsize) | |
# Run everything on the first CUDA GPU when one is available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# ADR severity classifier and its tokenizer; the model is moved to `device`
# so that direct forward passes and the pipeline below share one placement.
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1")
model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").to(device)

# Prediction pipeline handed to SHAP. `top_k=None` makes it return a score for
# every label (it replaces the deprecated `return_all_scores=True`).
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=device,
)

# SHAP explainer built directly on top of the classification pipeline.
explainer = shap.Explainer(pred)

# Biomedical NER pipeline. "simple" aggregation merges word pieces into whole
# entity groups. It is placed on the same device as the classifier so a GPU,
# when present, is used for both models.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)
| # def adr_predict(x): | |
# Highlight colour for each entity group produced by the NER pipeline; groups
# that are not listed fall back to light gray.
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}


def _highlight_entities(text, entities):
    """Return *text* as HTML with every NER entity wrapped in a coloured <mark>.

    Args:
        text: The exact string that was passed to the NER pipeline; entity
            ``start``/``end`` offsets index into it.
        entities: Iterable of dicts carrying ``start``, ``end`` and
            ``entity_group`` keys.

    Returns:
        An HTML fragment in which all text taken from *text* is HTML-escaped.
    """
    # Imported here because the file-level import header is outside this block.
    import html

    pieces = []
    cursor = 0
    for entity in sorted(entities, key=lambda ent: ent['start']):
        start, end = entity['start'], entity['end']
        if start < cursor:
            # Span overlaps one that was already emitted; rendering it too
            # would duplicate part of the text.
            continue
        color = _ENTITY_COLORS.get(entity['entity_group'], 'lightgray')
        # The text comes from the user, so escape it before embedding in HTML.
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f"<mark style='background-color:{color};'>{html.escape(text[start:end])}</mark>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(x):
    """Classify a text for adverse-drug-reaction severity and explain the result.

    Args:
        x: The text to analyse. It is converted with ``str()`` and lower-cased;
            ``None`` is treated as an empty string.

    Returns:
        A 3-tuple ``(label_output, local_plot, htext)``:

        * ``label_output`` -- dict mapping ``"Severe Reaction"`` and
          ``"Non-severe Reaction"`` to the softmax scores at class index 1
          and 0 respectively.
        * ``local_plot`` -- the value returned by ``shap.plots.text`` with
          ``display=False``, or a fallback HTML paragraph when the SHAP
          explanation fails.
        * ``htext`` -- the input text as HTML with recognised entities
          highlighted, or a fallback HTML paragraph when NER fails.
    """
    text_input = "" if x is None else str(x).lower()

    # Truncate to the model's maximum length so over-long input cannot crash
    # the forward pass.
    encoded_input = tokenizer(text_input, return_tensors='pt', truncation=True).to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)
    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # The SHAP explanation is best-effort: a failure must not prevent the
    # prediction itself from being shown.
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"

    # Entity highlighting is best-effort for the same reason.
    try:
        htext = _highlight_entities(text_input, ner_pipe(text_input))
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"

    # gr.Label expects a mapping of class name -> confidence.
    label_output = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return label_output, local_plot, htext
def main(prob1):
    """Gradio callback: analyse the submitted text.

    Forwards ``prob1`` to ``adr_predict`` and hands back whatever it returns,
    unchanged.
    """
    prediction_outputs = adr_predict(prob1)
    return prediction_outputs
title = "Welcome to **ADR Detector** 🪐"
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis."""

# The browser-tab title is shown as plain text, so strip the Markdown bold
# markers that are only meaningful in the on-page heading.
page_title = title.replace("**", "")

with gr.Blocks(title=page_title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")

    # Input component.
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")

    # Output components, in the same order as the values returned by `main`.
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label='Shap Explanation')
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")

    # Run the analysis when the button is clicked; `api_name` also exposes the
    # call through the app's API under the name "adr".
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")

    # Example outputs are not pre-computed (cache_examples=False); instead,
    # run_on_click=True runs `main` when an example is selected.
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

# Start the Gradio server.
demo.launch()