File size: 5,041 Bytes
dd40489 cdd1f84 dd40489 cdd1f84 dd40489 cdd1f84 dd40489 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import gradio as gr
from transformers import pipeline
import json
# --- Model setup -------------------------------------------------------------
# The fine-tuned DeBERTa checkpoint used for ethics-guideline classification.
MODEL_ID = "JohnLicode/ethics-review-deberta"

print("Loading model...")
# device="cpu": free-tier Spaces have no GPU, so force CPU inference.
classifier = pipeline("text-classification", model=MODEL_ID, device="cpu")
print("Model loaded!")
def classify_ethics(text: str, guideline_id: str = "", guideline_name: str = ""):
    """Classify a single text for ethics guideline compliance.

    Args:
        text: Research-proposal text to analyze.
        guideline_id: Optional guideline identifier (e.g., "1.1").
        guideline_name: Optional guideline title (e.g., "Objectives").

    Returns:
        dict with:
            "label": "ADDRESSED" or "NEEDS_REVISION" (raw model label if
                the model emits something unexpected),
            "score": confidence rounded to 4 decimals,
            "input_preview": first 100 chars of the model input, with "..."
                appended only when the preview actually truncates it.
    """
    # Format input the same way the training data was formatted.
    if guideline_id and guideline_name:
        input_text = f"Guideline {guideline_id} {guideline_name}: {text}"
    else:
        input_text = text
    # Truncate very long inputs before tokenization.
    input_text = input_text[:1500]
    # Get prediction (pipeline returns a list with one result per input).
    result = classifier(input_text)[0]
    # Map raw model labels to human-readable ones; unknown labels pass through.
    label_map = {"LABEL_0": "ADDRESSED", "LABEL_1": "NEEDS_REVISION"}
    label = label_map.get(result['label'], result['label'])
    # Bug fix: only add an ellipsis when the preview truncates the input;
    # previously "..." was appended even to short inputs, implying truncation.
    preview = input_text[:100] + ("..." if len(input_text) > 100 else "")
    return {
        "label": label,
        "score": round(result['score'], 4),
        "input_preview": preview,
    }
def classify_batch(batch_json: str):
    """
    Classify multiple texts in a single API call for better performance.
    Input: JSON string with format:
    [
      {"text": "...", "guideline_id": "1.1", "guideline_name": "Objectives"},
      {"text": "...", "guideline_id": "3.2", "guideline_name": "Privacy"},
      ...
    ]
    Output: JSON string with results for each input, or a JSON object with
    an "error" key when the input is malformed.
    """
    try:
        items = json.loads(batch_json)
    except json.JSONDecodeError as e:
        return json.dumps({"error": f"Invalid JSON: {str(e)}"})
    if not isinstance(items, list):
        return json.dumps({"error": "Input must be a JSON array"})
    if len(items) > 50:
        return json.dumps({"error": "Maximum 50 items per batch"})
    # Robustness fix: short-circuit on an empty batch instead of handing the
    # pipeline an empty list (which can raise).
    if not items:
        return json.dumps([])
    # Prepare all inputs, validating each element up front.
    formatted_inputs = []
    for i, item in enumerate(items):
        # Robustness fix: a non-object element (e.g. a bare string/number in
        # the array) previously crashed with AttributeError on .get; report
        # it as a structured error like the other validation paths.
        if not isinstance(item, dict):
            return json.dumps({"error": f"Item {i} must be a JSON object"})
        text = item.get("text", "")
        g_id = item.get("guideline_id", "")
        g_name = item.get("guideline_name", "")
        # Format each input the same way the training data was formatted.
        if g_id and g_name:
            input_text = f"Guideline {g_id} {g_name}: {text}"
        else:
            input_text = text
        # Truncate very long inputs before tokenization.
        formatted_inputs.append(input_text[:1500])
    # Run batch inference (much faster than individual calls).
    predictions = classifier(formatted_inputs)
    # Map raw model labels to human-readable ones; unknown labels pass through.
    label_map = {"LABEL_0": "ADDRESSED", "LABEL_1": "NEEDS_REVISION"}
    results = [
        {
            "label": label_map.get(pred['label'], pred['label']),
            "score": round(pred['score'], 4),
        }
        for pred in predictions
    ]
    return json.dumps(results)
# Build the Gradio UI: one tab per endpoint (single item vs. batch JSON).
with gr.Blocks(title="Ethics Review Classifier") as demo:
    gr.Markdown("# Ethics Review Classifier")
    gr.Markdown("Classify research proposal text against ethics guidelines. Returns ADDRESSED or NEEDS_REVISION.")

    with gr.Tab("Single Classification"):
        with gr.Row():
            # Left column: free-text input plus optional guideline metadata.
            with gr.Column():
                txt_box = gr.Textbox(label="Text to Analyze", lines=5, placeholder="Enter the text from research proposal...")
                gid_box = gr.Textbox(label="Guideline ID (optional)", placeholder="e.g., 1.1")
                gname_box = gr.Textbox(label="Guideline Name (optional)", placeholder="e.g., Objectives")
                run_single = gr.Button("Classify", variant="primary")
            # Right column: structured classification result.
            with gr.Column():
                result_json = gr.JSON(label="Result")
        run_single.click(classify_ethics, inputs=[txt_box, gid_box, gname_box], outputs=result_json)
        # Clickable sample inputs covering a few different guidelines.
        gr.Examples(
            examples=[
                ["The general objective is to develop an AI ethics review system. Specific objectives: 1) Create scanning module 2) Implement matching.", "1.1", "Objectives"],
                ["All participant data will be encrypted using AES-256 and stored securely.", "3.2", "Privacy and confidentiality"],
                ["The study explores innovative approaches.", "1.7", "Sampling design and size"],
            ],
            inputs=[txt_box, gid_box, gname_box],
        )

    with gr.Tab("Batch Classification (Fast)"):
        gr.Markdown("**For API users:** Send up to 50 items in one request for faster processing.")
        batch_box = gr.Textbox(
            label="Batch Input (JSON Array)",
            lines=10,
            placeholder='[{"text": "...", "guideline_id": "1.1", "guideline_name": "Objectives"}, ...]'
        )
        run_batch = gr.Button("Classify Batch", variant="primary")
        batch_result = gr.Textbox(label="Batch Results (JSON)", lines=10)
        run_batch.click(classify_batch, inputs=[batch_box], outputs=batch_result)

# Start the app (the Gradio API endpoints come up with the UI).
demo.launch()
|