import json

import gradio as gr
from transformers import pipeline

# Model inputs are cut to this many characters before classification.
MAX_INPUT_CHARS = 1500
# Upper bound on the number of items accepted by one batch request.
MAX_BATCH_ITEMS = 50
# Raw pipeline label -> human-readable verdict. Labels not listed here are
# passed through unchanged.
LABEL_NAMES = {
    "LABEL_0": "ADDRESSED",
    "LABEL_1": "NEEDS_REVISION",
}

# Load the model once at import time so every request reuses it.
print("Loading model...")
classifier = pipeline(
    "text-classification",
    model="JohnLicode/ethics-review-deberta",
    device="cpu",  # Use CPU for free tier
)
print("Model loaded!")


def _format_input(text, guideline_id="", guideline_name=""):
    """Build the model input string in the same shape as the training data.

    When both guideline fields are non-empty the text is prefixed with
    ``"Guideline {id} {name}: "``; otherwise the text is used as-is.
    ``None`` text becomes ``""`` and other non-string values are converted
    with ``str()`` so slicing never fails. The result is truncated to
    ``MAX_INPUT_CHARS`` characters.
    """
    text = "" if text is None else str(text)
    if guideline_id and guideline_name:
        text = f"Guideline {guideline_id} {guideline_name}: {text}"
    return text[:MAX_INPUT_CHARS]


def _label_name(raw_label):
    """Map a raw pipeline label to its readable name (unknown labels pass through)."""
    return LABEL_NAMES.get(raw_label, raw_label)


def classify_ethics(text: str, guideline_id: str = "", guideline_name: str = ""):
    """Classify a single text for ethics-guideline compliance.

    Args:
        text: Research-proposal passage to classify.
        guideline_id: Optional guideline identifier, e.g. ``"1.1"``.
        guideline_name: Optional guideline title, e.g. ``"Objectives"``.
            The guideline prefix is only added when both fields are set.

    Returns:
        dict with ``label`` (``"ADDRESSED"``, ``"NEEDS_REVISION"``, or the
        raw pipeline label if unmapped), ``score`` rounded to 4 decimals,
        and ``input_preview``: the first 100 characters of the model input,
        with ``"..."`` appended only when the input was longer than that.
    """
    input_text = _format_input(text, guideline_id, guideline_name)
    result = classifier(input_text)[0]

    preview = input_text[:100]
    if len(input_text) > 100:
        preview += "..."

    return {
        "label": _label_name(result["label"]),
        "score": round(result["score"], 4),
        "input_preview": preview,
    }


def classify_batch(batch_json: str):
    """
    Classify multiple texts in a single API call.

    Input: JSON string with format:
    [
        {"text": "...", "guideline_id": "1.1", "guideline_name": "Objectives"},
        {"text": "...", "guideline_id": "3.2", "guideline_name": "Privacy"},
        ...
    ]

    Output: JSON string holding a list of ``{"label", "score"}`` objects,
    one per input item and in the same order. On invalid input (unparseable
    JSON, not an array, more than MAX_BATCH_ITEMS items, or an element that
    is not an object) a JSON object ``{"error": "..."}`` is returned instead.
    """
    try:
        items = json.loads(batch_json)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers a non-string payload such as None.
        return json.dumps({"error": f"Invalid JSON: {str(e)}"})

    if not isinstance(items, list):
        return json.dumps({"error": "Input must be a JSON array"})

    if len(items) > MAX_BATCH_ITEMS:
        return json.dumps({"error": f"Maximum {MAX_BATCH_ITEMS} items per batch"})

    # Validate every element up front so a malformed item yields a JSON error
    # rather than an AttributeError from item.get().
    for index, item in enumerate(items):
        if not isinstance(item, dict):
            return json.dumps({"error": f"Item {index} must be a JSON object"})

    if not items:
        return json.dumps([])

    formatted_inputs = [
        _format_input(
            item.get("text", ""),
            item.get("guideline_id", ""),
            item.get("guideline_name", ""),
        )
        for item in items
    ]

    # One pipeline call for the whole list. Without an explicit batch_size
    # the pipeline still runs one forward pass per item; the saving is a
    # single request round trip instead of one per item.
    predictions = classifier(formatted_inputs)

    results = [
        {"label": _label_name(pred["label"]), "score": round(pred["score"], 4)}
        for pred in predictions
    ]
    return json.dumps(results)


# Gradio interface with both single and batch endpoints.
with gr.Blocks(title="Ethics Review Classifier") as demo:
    gr.Markdown("# Ethics Review Classifier")
    gr.Markdown(
        "Classify research proposal text against ethics guidelines. "
        "Returns ADDRESSED or NEEDS_REVISION."
    )

    with gr.Tab("Single Classification"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text to Analyze",
                    lines=5,
                    placeholder="Enter the text from research proposal...",
                )
                id_input = gr.Textbox(
                    label="Guideline ID (optional)", placeholder="e.g., 1.1"
                )
                name_input = gr.Textbox(
                    label="Guideline Name (optional)", placeholder="e.g., Objectives"
                )
                single_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                single_output = gr.JSON(label="Result")

        single_btn.click(
            classify_ethics,
            inputs=[text_input, id_input, name_input],
            outputs=single_output,
        )

        gr.Examples(
            examples=[
                [
                    "The general objective is to develop an AI ethics review system. "
                    "Specific objectives: 1) Create scanning module 2) Implement matching.",
                    "1.1",
                    "Objectives",
                ],
                [
                    "All participant data will be encrypted using AES-256 and stored securely.",
                    "3.2",
                    "Privacy and confidentiality",
                ],
                [
                    "The study explores innovative approaches.",
                    "1.7",
                    "Sampling design and size",
                ],
            ],
            inputs=[text_input, id_input, name_input],
        )

    with gr.Tab("Batch Classification (Fast)"):
        gr.Markdown(
            "**For API users:** Send up to 50 items in one request for faster processing."
        )
        batch_input = gr.Textbox(
            label="Batch Input (JSON Array)",
            lines=10,
            placeholder='[{"text": "...", "guideline_id": "1.1", "guideline_name": "Objectives"}, ...]',
        )
        batch_btn = gr.Button("Classify Batch", variant="primary")
        batch_output = gr.Textbox(label="Batch Results (JSON)", lines=10)

        batch_btn.click(classify_batch, inputs=[batch_input], outputs=batch_output)

# Launch the app; event handlers are also reachable through Gradio's API.
demo.launch()