"""Gradio app that predicts which justice authored an opinion.

Per curiam opinions from the ``raminass/opinions-94-23`` dataset are offered in a
dropdown, or the user can paste any opinion text. The text is cleaned and chunked
with helpers from the local ``utils`` module, scored by the ``raminass/m4``
Hugging Face pipeline, and the prediction is restricted to the justices selected
for the chosen year.
"""

import json

import gradio as gr
import pandas as pd
from datasets import load_dataset
from transformers import pipeline

from utils import average_text, chunk_data, remove_citations

# Range of the year slider; MIN_YEAR is also the value the Clear button resets to.
MIN_YEAR = 1994
MAX_YEAR = 2023
# Per curiam opinions with at most this many characters are left out of the dropdown.
MIN_OPINION_CHARS = 1000

pipe = pipeline(model="raminass/m4", top_k=17, padding=True, truncation=True)

# Earlier revisions loaded "raminass/full_opinions_1994_2020".
dataset = load_dataset("raminass/opinions-94-23")
df = pd.DataFrame(dataset["train"])

percuriams = df[df["type"] == "per_curiam"].copy()
percuriams["case_name"] = percuriams["case_name"].str.strip()
percuriams = percuriams.sort_values(by="case_name", key=lambda s: s.str.lower())

# Dropdown entries as (label, value) pairs. The value carries the opinion text and
# its year so one selection can fill both the textbox and the year slider. The
# list keeps the case-insensitive order established just above.
choices = [
    (str(name), [text, int(year)])
    for name, text, year in zip(
        percuriams["case_name"], percuriams["text"], percuriams["year"]
    )
    if len(text) > MIN_OPINION_CHARS
]

# Justice names keyed by year. JSON object keys are always strings, so convert
# them to int to match the slider's value.
with open("j_year.json", encoding="utf-8") as f:
    judges_by_year = {int(year): names for year, names in json.load(f).items()}


def greet(opinion, judges_l):
    """Predict the author of ``opinion`` among the justices in ``judges_l``.

    The text has its citations removed, is split into chunks, and the chunks are
    scored with ``average_text``. ``average_text(...)[0]`` is used as a mapping
    from justice name to score (presumed to be in [0, 1] -- confirm in utils).

    Returns:
        A pair: the raw score mapping (for the ``gr.Label``) and the same scores
        as percentages rounded to two decimals (for the output text box).

    Raises:
        gr.Error: if no opinion text is given or no justice is selected, so the
            UI shows a message instead of failing inside the model code.
    """
    if not opinion or not opinion.strip():
        raise gr.Error("Please paste an opinion or select one from the list.")
    if not judges_l:
        raise gr.Error("Please select at least one justice.")
    chunks = chunk_data(remove_citations(opinion))["text"].to_list()
    scores = average_text(chunks, pipe, judges_l)[0]
    return scores, {k: round(v * 100, 2) for k, v in scores.items()}


def set_input(drop):
    """Copy a selected dropdown entry's opinion text and year into the inputs.

    ``drop`` is the value of a ``choices`` entry: ``[text, year]``.
    """
    return drop[0], drop[1]


def update_year(year):
    """Return a justice checkbox group for ``year`` with every justice checked.

    A year missing from ``j_year.json`` yields an empty group instead of raising
    ``KeyError``. ``int()`` guards against the slider delivering a float.
    """
    justices = judges_by_year.get(int(year), [])
    return gr.CheckboxGroup(
        justices,
        value=justices,
        label="Select Justices",
    )


# Paragraph text
paragraph_text = (
    "One can refine these observations based on the prediction scores obtained in each case. "
    "As explained in the Methods, these scores do not correspond to probabilities but can be calibrated "
    "based on the cross-validation results. In particular, when we trained the algorithm we noticed that if "
    "the top prediction score is greater than 40% (50%, 60%), our accuracy in predicting the authoring justice "
    "increases to 93% (95%, 96%, respectively). Similarly, the original accuracy further improves to 95% when "
    "considering the top two predictions per opinion, rather than a single one. "
    "If the sum of the top two prediction "
    "scores exceeds 50%, the accuracy increases to 98%."
)

# Layout reference: https://www.gradio.app/guides/controlling-layout
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            drop = gr.Dropdown(
                # Already ordered case-insensitively above. sorted() here would
                # re-sort case-sensitively and undo that ordering.
                choices=choices,
                label="List of Per Curiam Opinions",
                info="Select a per curiam opinion from the dropdown menu and press the Predict Button",
            )
            year = gr.Slider(
                MIN_YEAR,
                MAX_YEAR,
                value=MIN_YEAR,
                step=1,
                label="Year",
                info="Select the year of the opinion if you manually paste the opinion below",
            )
            initial_justices = judges_by_year.get(MIN_YEAR, [])
            exc_judg = gr.CheckboxGroup(
                initial_justices,
                value=initial_justices,
                label="Select Justices",
                info="Select justices to consider in prediction",
            )
            opinion = gr.Textbox(
                label="Opinion", info="Paste opinion text here and press the Predict Button"
            )
        with gr.Column(scale=1):
            with gr.Row():
                clear_btn = gr.Button("Clear")
                greet_btn = gr.Button("Predict")
            op_level = gr.Label(
                num_top_classes=9, label="Predicted author of opinion"
            )
            output_textbox = gr.Textbox(
                label="Output Text",
                buttons=["copy"],  # shows a copy button
            )
            info_textbox = gr.Textbox(
                value=paragraph_text,
                label="Additional Insights",
                interactive=False,  # Makes the textbox read-only
            )

    # ``change`` fires both on user input and on programmatic updates (selecting a
    # case, Clear). A separate ``release`` listener only repeated the same update.
    year.change(
        update_year,
        inputs=[year],
        outputs=[exc_judg],
    )
    drop.select(set_input, inputs=drop, outputs=[opinion, year])
    greet_btn.click(
        fn=greet,
        inputs=[opinion, exc_judg],
        outputs=[op_level, output_textbox],
    )
    # Reset every input and both outputs. Resetting the year also triggers
    # ``year.change``, which restores that year's justice list.
    clear_btn.click(
        fn=lambda: [None, MIN_YEAR, None, None, None],
        outputs=[opinion, year, drop, op_level, output_textbox],
    )

if __name__ == "__main__":
    demo.launch(
        # auth=("sc2024", "sc2024"),
        # auth_message="To request access, please email ronen3112@gmail.com",
    )