| import gradio as gr |
| import json |
| import os |
| from openai import OpenAI |
|
|
| |
# Annotation records to review (read by load_data).
DATA_PATH = '/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/code/correction_evaluation_full_text_with_gs.json'
# Directory that receives one "final_corrected_<user>.json" file per annotator.
SAVE_DIR = '/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/correction_data/'
# Prompt text containing the <<<...>>> placeholders filled in by call_ai_processor.
PROMPT_TEMPLATE_PATH = "/home/mshahidul/readctrl/prompts/syn_data_gen_diff_label_mod.txt"
# JSON file with API credentials; only its "openai" entry is used here.
API_FILE_PATH = "/home/mshahidul/api_new.json"


# Build the OpenAI client once, at import time, from the stored key.
# The encoding is explicit so the read does not depend on the process locale.
with open(API_FILE_PATH, "r", encoding="utf-8") as f:
    api_keys = json.load(f)
client = OpenAI(api_key=api_keys["openai"])


# Load the prompt template once; it is reused for every refinement request.
with open(PROMPT_TEMPLATE_PATH, "r", encoding="utf-8") as f:
    PROMPT_TEMPLATE = f.read()
|
|
def load_data(path=None):
    """Load the annotation records from a JSON file.

    Args:
        path: JSON file to read. When omitted, the module-level
            ``DATA_PATH`` is used, which keeps ``load_data()`` working
            exactly as before.

    Returns:
        The decoded JSON content, or an empty list when the file does
        not exist.
    """
    if path is None:
        path = DATA_PATH
    if not os.path.exists(path):
        return []
    # Read as UTF-8 explicitly: the records may hold non-ASCII text, and the
    # locale's default encoding is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
# Records are read once at import time; every handler below indexes into
# this shared list by position.
DATA: list = load_data()
|
|
| |
def call_ai_processor(index, full_text, gold_summary):
    """Ask the OpenAI model to regenerate the text for the current record's label.

    The prompt is built from ``PROMPT_TEMPLATE`` by filling its ``<<<...>>>``
    placeholders, sent to ``gpt-5-mini`` in JSON-object mode, and the value
    stored under the record's ``ai_label`` key is extracted from the reply.

    Args:
        index: Position of the current record in ``DATA``.
        full_text: Source document text for the record.
        gold_summary: Reference summary for the record.

    Returns:
        The refined text for the target label, or a readable message
        (starting with ``"AI Error:"`` or ``"Error:"``) when the request
        cannot be made or the reply lacks the label. Failures are returned
        as text instead of raised because the result is shown in a textbox.
    """
    import re  # local import: only this function needs it

    try:
        # After the last record the UI index points one past the end of DATA.
        if not 0 <= index < len(DATA):
            return "AI Error: no record is loaded at the current position."

        item = DATA[index]
        target_label = item.get('ai_label')
        if not target_label:
            return "AI Error: the current record has no 'ai_label' to refine."

        # `or` also covers records that store an explicit null language.
        source_lang = item.get('language') or 'English'

        substitutions = {
            "<<<FULL_TEXT>>>": full_text or "",
            "<<<SOURCE_LANGUAGE>>>": source_lang,
            "<<<GOLD_SUMMARY>>>": gold_summary or "",
            "<<<TARGET_LABEL>>>": target_label,
        }
        # Fill every placeholder in ONE pass over the template. Chained
        # str.replace calls would also rewrite placeholder-looking tokens
        # that happen to occur inside the inserted document text.
        placeholder_re = re.compile(
            "|".join(re.escape(token) for token in substitutions)
        )
        prompt = placeholder_re.sub(
            lambda match: substitutions[match.group(0)], PROMPT_TEMPLATE
        )

        response = client.chat.completions.create(
            model="gpt-5-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )

        content = json.loads(response.choices[0].message.content)

        # The reply is looked up by the record's own label key.
        refined_text = content.get(target_label, "Error: Label not found in AI response.")
        return refined_text

    except Exception as e:
        # UI boundary: surface the failure in the textbox rather than crash
        # the Gradio event handler.
        return f"AI Error: {str(e)}"
|
|
| |
def get_user_save_path(username):
    """Return the results-file path for *username* inside ``SAVE_DIR``.

    Only characters for which ``str.isalpha()`` or ``str.isdigit()`` is true
    are kept, so the name cannot introduce path separators or whitespace.
    The file is named ``final_corrected_<kept characters>.json``.
    """
    kept_chars = (ch for ch in username if ch.isalpha() or ch.isdigit())
    safe_name = "".join(kept_chars)
    return os.path.join(SAVE_DIR, f"final_corrected_{safe_name}.json")
|
|
def load_user_results(username):
    """Return the entries already saved for *username*.

    Reads the JSON file located by ``get_user_save_path``; when that file
    does not exist yet, an empty list is returned instead.
    """
    results_file = get_user_save_path(username)
    if not os.path.exists(results_file):
        return []
    with open(results_file, 'r') as handle:
        saved_entries = json.load(handle)
    return saved_entries
|
|
def get_record(index):
    """Build the display values for the record at *index* in ``DATA``.

    Returns:
        ``None`` when *index* is outside ``DATA``; otherwise a 7-tuple of
        (doc id, annotator ratings text, title-cased label, full text,
        AI text for the label, index, gold summary).
    """
    if not 0 <= index < len(DATA):
        return None

    record = DATA[index]
    label = record.get('ai_label', '')
    texts_by_label = record.get('diff_label_texts', {})
    label_text = texts_by_label.get(label, "Text not found")
    summary = record.get('summary', '')

    # One line per human annotator: their category and rating for this record.
    rating_lines = [
        f"Plaban: {record.get('category_plaban')} (Rating: {record.get('rating_plaban')})",
        f"Mahi: {record.get('category_mahi')} (Rating: {record.get('rating_mahi')})",
        f"Shama: {record.get('category_shama')} (Rating: {record.get('rating_shama')})",
    ]
    ratings_text = "\n".join(rating_lines)

    doc_id = record.get('doc_id')
    # "some_label" is shown as "Some Label".
    pretty_label = label.replace("_", " ").title()
    source_text = record.get('fulltext')
    return (doc_id, ratings_text, pretty_label, source_text, label_text, index, summary)
|
|
def login_user(username):
    """Start or resume an annotation session for *username*.

    Returns a 9-tuple matching the login button's outputs: visibility
    updates for the login row and the main panel, the index to resume at,
    then doc id, annotator ratings, label, full text, AI text and gold
    summary. The resume index is the number of entries the user has
    already saved.
    """
    # Blank or whitespace-only name: keep the login row visible, show nothing.
    if not (username and username.strip()):
        return gr.update(visible=True), gr.update(visible=False), 0, None, "", "", "", "", ""

    resume_at = len(load_user_results(username))

    hide_login = gr.update(visible=False)
    show_main = gr.update(visible=True)

    if resume_at >= len(DATA):
        # Every record already has a saved entry for this user.
        return (
            hide_login, show_main, resume_at,
            "Finished!", "All caught up!", "No more data.", "No more data.", "", "",
        )

    # get_record also returns the index (6th field); it is not needed here.
    doc_id, ratings, label, source_text, ai_text, _, gold = get_record(resume_at)
    return (
        hide_login, show_main, resume_at,
        doc_id, ratings, label, source_text, ai_text, gold,
    )
|
|
def save_and_next(username, index, corrected_text, is_ok):
    """Record the annotator's decision for the current record and advance.

    Args:
        username: Annotator name; selects the results file.
        index: Position of the record being decided in ``DATA``.
        corrected_text: Content of the correction textbox.
        is_ok: ``True`` to approve the original AI text, ``False`` to save
            *corrected_text* instead.

    Returns:
        A list of 8 values for the action outputs: doc id, annotator
        ratings, label, full text, AI text, new index, gold summary, and an
        empty string that clears the correction textbox. Once the records
        are exhausted the "Finished!" placeholders are returned instead.
    """
    # The buttons stay clickable after the last record. Without this guard
    # a further click would raise IndexError on DATA[index].
    if not 0 <= index < len(DATA):
        return [None, "Finished!", "Finished!", "No more data.", "No more data.", index, "No more data.", ""]

    user_results = load_user_results(username)
    current_item = DATA[index]

    # Look the original text up once; it is stored in both branches.
    original_ai_text = current_item.get('diff_label_texts', {}).get(current_item['ai_label'])
    final_text = original_ai_text if is_ok else corrected_text

    result_entry = {
        "doc_id": current_item['doc_id'],
        "ai_label": current_item['ai_label'],
        "status": "Approved" if is_ok else "Manually Corrected/AI Refined",
        "final_text": final_text,
        "original_ai_text": original_ai_text
    }

    user_results.append(result_entry)

    save_path = get_user_save_path(username)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # The first save on a fresh machine should not fail on a missing folder.
        os.makedirs(save_dir, exist_ok=True)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # cannot truncate the annotations saved so far.
    tmp_path = save_path + ".tmp"
    with open(tmp_path, 'w') as f:
        json.dump(user_results, f, indent=4)
    os.replace(tmp_path, save_path)

    next_index = index + 1
    if next_index < len(DATA):
        res = get_record(next_index)
        return list(res) + [""]
    return [None, "Finished!", "Finished!", "No more data.", "No more data.", next_index, "No more data.", ""]
|
|
| |
# NOTE(review): the nesting of the layout below is reconstructed from a
# flattened copy of the source; confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📝 AI Label Correction Interface (v2 with GPT-Refinement)")

    # Session state: record position, annotator name, and the gold summary
    # of the current record (kept out of sight but passed to the AI call).
    current_idx = gr.State(0)
    user_session = gr.State("")
    gold_summary_hidden = gr.State("")

    # Login row: the only thing visible until a username is submitted.
    with gr.Row() as login_row:
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                label="Enter Username to Resume",
                placeholder="e.g., Shahidul",
            )
            btn_login = gr.Button("Start Annotation", variant="primary")

    # Main annotation panel, revealed by login_user.
    with gr.Column(visible=False) as main_container:
        with gr.Row():
            with gr.Column(scale=1):
                doc_id_display = gr.Textbox(label="Document ID", interactive=False)
                ai_label_display = gr.Label(label="Target AI Label")
                annotator_stats = gr.Textbox(
                    label="Human Annotator Ratings",
                    lines=4,
                    interactive=False,
                )

            with gr.Column(scale=2):
                full_text_display = gr.Textbox(
                    label="Source Full Text",
                    lines=10,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column():
                ai_generated_text = gr.Textbox(
                    label="Original AI Text",
                    lines=6,
                    interactive=False,
                )
            with gr.Column():
                manual_correction = gr.Textbox(
                    label="AI Refinement / Manual Correction",
                    placeholder="AI generated text will appear here...",
                    lines=6,
                )
                btn_ai_check = gr.Button("✨ Check & Refine through AI", variant="secondary")

        with gr.Row():
            btn_ok = gr.Button("✅ Original Text is OK", variant="primary")
            btn_fix = gr.Button("💾 Save Current Correction/AI Text", variant="stop")

    # Login: populate the panel, then remember the username for later saves.
    login_outputs = [
        login_row,
        main_container,
        current_idx,
        doc_id_display,
        annotator_stats,
        ai_label_display,
        full_text_display,
        ai_generated_text,
        gold_summary_hidden,
    ]
    login_event = btn_login.click(fn=login_user, inputs=[user_input], outputs=login_outputs)
    login_event.then(fn=lambda username: username, inputs=[user_input], outputs=[user_session])

    # AI refinement writes its result into the editable correction box.
    btn_ai_check.click(
        fn=call_ai_processor,
        inputs=[current_idx, full_text_display, gold_summary_hidden],
        outputs=[manual_correction],
    )

    # Both save buttons share inputs and outputs; only the is_ok flag differs.
    action_inputs = [user_session, current_idx, manual_correction]
    action_outputs = [
        doc_id_display,
        annotator_stats,
        ai_label_display,
        full_text_display,
        ai_generated_text,
        current_idx,
        gold_summary_hidden,
        manual_correction,
    ]

    btn_ok.click(
        fn=lambda user, idx, txt: save_and_next(user, idx, txt, True),
        inputs=action_inputs,
        outputs=action_outputs,
    )

    btn_fix.click(
        fn=lambda user, idx, txt: save_and_next(user, idx, txt, False),
        inputs=action_inputs,
        outputs=action_outputs,
    )
|
|
# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    # NOTE(review): share=True presumably also publishes a temporary public
    # link to this app; confirm that exposure is intended.
    demo.launch(share=True)