# Gradio dashboard for the pharmaceutical litigation research pipeline.
import datetime
import json

import gradio as gr
import pandas as pd

from db import init_db, DocumentLog
from pipeline import LitigationPipeline, MODELS
# --- DEFAULT PROMPTS ---
# System prompt for the summarization pass. Shown (editable) in the
# "Prompt Engineering" accordion before a batch is run.
DEFAULT_SUM_PROMPT = """You are an expert pharmaceutical litigation researcher.
Summarize the provided clinical trial document or internal correspondence.
Focus on: Adverse events, off-label promotion, and knowledge of side effects.
Format output as Markdown.
"""
# Prompt for the QC/verifier pass. NOTE(review): the pipeline presumably
# parses the "SCORE: .../10" line out of the model response — keep the
# "Output format STRICTLY as" section intact when editing (TODO confirm
# against pipeline.process_batch).
DEFAULT_VER_PROMPT = """You are a Senior QC Editor.
Review the provided summary against the original text snippet.
Rating Rubric:
- 10: Perfect capture of adverse events and dates.
- 1-9: Missing key dates or hallucinations.
Output format STRICTLY as:
SCORE: [0-10]/10
FEEDBACK: [Specific critique here]
"""
| # --- UI LOGIC --- | |
def run_pipeline(files, api_key, model, context_limit, retries, sum_prompt, ver_prompt):
    """Run a full batch through the litigation pipeline.

    Returns a tuple of (markdown status message, logs DataFrame), suitable
    for binding directly to the status Markdown and log Dataframe widgets.
    Errors never propagate to Gradio: every failure path is converted into
    a user-visible markdown message plus an empty DataFrame.
    """
    # Guard clauses: validate inputs up front instead of raising.
    if not files:
        return "β οΈ **Error**: No files uploaded", pd.DataFrame()
    if not api_key:
        return "β οΈ **Error**: Please enter an API Key", pd.DataFrame()

    try:
        engine = LitigationPipeline(api_key, model, context_limit)
        batch = engine.process_batch(files, sum_prompt, ver_prompt, retries)
        logs = engine.get_logs()

        # Tally outcomes; each batch entry's second field is its status.
        doc_count = len(files)
        ok = sum(1 for entry in batch if entry[1] == 'SUCCESS')
        bad = doc_count - ok
        status_msg = f"""### β Batch Processing Complete
**Summary:**
- **Total Documents**: {doc_count}
- **β Successful**: {ok}
- **β Failed**: {bad}
{f'β οΈ **Warning**: {bad} document(s) failed processing. Check logs below for details.' if bad > 0 else 'π All documents processed successfully!'}
"""
        return status_msg, logs
    except ValueError as e:
        # Configuration problems (bad key/model/limits) get targeted advice.
        return f"""### β Configuration Error
**Error Details:**
```
{str(e)}
```
**Troubleshooting:**
- Check that your API key is valid
- Verify the selected model is supported
- Ensure all configuration values are correct
""", pd.DataFrame()
    except Exception as e:
        # Anything else: surface the error text plus generic guidance.
        return f"""### β Unexpected Error
**Error Details:**
```
{str(e)}
```
**What to do:**
- Check the console logs for more details
- Verify your PDF files are not corrupted
- Try processing fewer files at once
- Contact support if the issue persists
""", pd.DataFrame()
def get_initial_logs():
    """Fetch the 50 most recent document logs to populate the UI table.

    Returns a DataFrame with `verification_score`/`verifier_feedback`
    renamed to `score`/`feedback`. NOTE(review): the Dataframe widget's
    default headers still use the un-renamed names — confirm which set
    the UI actually expects.
    """
    db_session = init_db()
    try:
        query = db_session.query(DocumentLog).order_by(DocumentLog.timestamp.desc()).limit(50)
        df = pd.read_sql(query.statement, db_session.bind)
    finally:
        # FIX: the session was never closed; release it so every page
        # load / refresh doesn't leak a DB connection.
        db_session.close()
    return df.rename(columns={
        "verification_score": "score",
        "verifier_feedback": "feedback",
    })
def export_logs_csv():
    """Export all document logs to a timestamped CSV and return its path.

    The file is written to the current working directory; the returned
    filename is handed to a gr.File component for download.
    """
    db_session = init_db()
    try:
        query = db_session.query(DocumentLog).order_by(DocumentLog.timestamp.desc())
        df = pd.read_sql(query.statement, db_session.bind)
    finally:
        # FIX: close the session so repeated exports don't leak connections.
        db_session.close()
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"litigation_logs_{timestamp}.csv"
    df.to_csv(filename, index=False)
    return filename
def export_logs_json():
    """Export all document logs to a timestamped JSON file and return its path.

    Timestamps are stringified first because datetime objects are not
    JSON-serializable.
    """
    db_session = init_db()
    try:
        query = db_session.query(DocumentLog).order_by(DocumentLog.timestamp.desc())
        df = pd.read_sql(query.statement, db_session.bind)
    finally:
        # FIX: close the session so repeated exports don't leak connections.
        db_session.close()
    df['timestamp'] = df['timestamp'].astype(str)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"litigation_logs_{timestamp}.json"
    with open(filename, 'w') as f:
        json.dump(df.to_dict(orient='records'), f, indent=2)
    return filename
# --- GRADIO LAYOUT ---
with gr.Blocks(title="Pharma Litigation AI Researcher") as demo:
    gr.Markdown("## βοΈ Pharmaceutical Litigation Research Dashboard")
    with gr.Row():
        # Left column: prompts, file upload and run configuration.
        with gr.Column(scale=1):
            gr.Markdown("Upload batch PDFs, configure the verifier loop, and extract litigation-critical insights.")
            # Prompt Engineering
            with gr.Accordion("π Prompt Engineering (Expand to Edit)", open=False):
                sum_prompt_box = gr.Textbox(label="Summarizer Prompt", value=DEFAULT_SUM_PROMPT, lines=5)
                ver_prompt_box = gr.Textbox(label="Verifier Prompt", value=DEFAULT_VER_PROMPT, lines=5)
            # File Upload
            file_uploader = gr.File(label="Upload Discovery Documents (PDF)", file_count="multiple", type="filepath")
            # Configuration Panel
            gr.Markdown("---")
            gr.Markdown("### βοΈ Configuration")
            api_key_input = gr.Textbox(label="API Key", type="password", placeholder="api-key...")
            model_selector = gr.Dropdown(list(MODELS.keys()), label="Model", value=list(MODELS.keys())[0])
            context_limit = gr.Number(label="Reset Context After (N docs)", value=5, precision=0)
            retry_limit = gr.Slider(minimum=0, maximum=5, value=2, step=1, label="Max Verification Retries")
            run_btn = gr.Button("π Start Analysis Batch", variant="primary")
        # Right column: live status, log table and export controls.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("π Live Execution Logs"):
                    status_output = gr.Markdown("Waiting for input...")
                    # Headers are defaults only; the real columns come from
                    # the DataFrames the handlers return. FIX: the first
                    # header was " id" (leading space), which could never
                    # match an actual "id" column.
                    log_table = gr.Dataframe(
                        label="Processing Logs (Database)",
                        headers=["id", "filename", "status", "retry_count",
                                 "verification_score", "verifier_feedback",
                                 "final_summary", "model_used",
                                 "cost_estimate", "timestamp"],
                    )
                    # Export Buttons
                    gr.Markdown("### π₯ Export Results")
                    with gr.Row():
                        export_csv_btn = gr.Button("π Download as CSV", size="sm")
                        export_json_btn = gr.Button("π Download as JSON", size="sm")
                    csv_download = gr.File(label="CSV Download", visible=False)
                    json_download = gr.File(label="JSON Download", visible=False)

    # Event wiring: feed every config widget into run_pipeline.
    run_btn.click(
        fn=run_pipeline,
        inputs=[
            file_uploader,
            api_key_input,
            model_selector,
            context_limit,
            retry_limit,
            sum_prompt_box,
            ver_prompt_box,
        ],
        outputs=[status_output, log_table],
    )
    # Export handlers: write the file, then un-hide the download component.
    export_csv_btn.click(
        fn=export_logs_csv,
        inputs=[],
        outputs=[csv_download],
    ).then(
        lambda: gr.File(visible=True),
        outputs=[csv_download],
    )
    export_json_btn.click(
        fn=export_logs_json,
        inputs=[],
        outputs=[json_download],
    ).then(
        lambda: gr.File(visible=True),
        outputs=[json_download],
    )
    # Populate the log table from the database on page load.
    demo.load(get_initial_logs, None, log_table)

if __name__ == "__main__":
    demo.launch()