File size: 8,195 Bytes
1db7196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import gradio as gr
import json
import os
from openai import OpenAI

# --- CONFIGURATION ---
# Each path can be overridden through an environment variable so the tool can
# run outside the original machine; the defaults are the original paths.
DATA_PATH = os.environ.get(
    "READCTRL_DATA_PATH",
    '/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/code/correction_evaluation_full_text_with_gs.json',
)
SAVE_DIR = os.environ.get(
    "READCTRL_SAVE_DIR",
    '/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/correction_data/',
)
PROMPT_TEMPLATE_PATH = os.environ.get(
    "READCTRL_PROMPT_TEMPLATE_PATH",
    "/home/mshahidul/readctrl/prompts/syn_data_gen_diff_label_mod.txt",
)
API_FILE_PATH = os.environ.get("READCTRL_API_FILE_PATH", "/home/mshahidul/api_new.json")

# --- INITIALIZATION ---
# Load API Key: the JSON file must contain the key under "openai".
with open(API_FILE_PATH, "r", encoding="utf-8") as f:
    api_keys = json.load(f)
client = OpenAI(api_key=api_keys["openai"])

# Load Prompt Template. Read as UTF-8 explicitly so any non-ASCII template text
# does not depend on the platform's default locale encoding.
with open(PROMPT_TEMPLATE_PATH, "r", encoding="utf-8") as f:
    PROMPT_TEMPLATE = f.read()

def load_data(path=None):
    """Load the records to be reviewed from a JSON file.

    Args:
        path: JSON file to read. Defaults to the module-level ``DATA_PATH``;
            the parameter lets callers (and tests) point at another file.

    Returns:
        The parsed JSON content (the app indexes it as a list of record
        dicts), or an empty list when the file does not exist.
    """
    if path is None:
        path = DATA_PATH
    if not os.path.exists(path):
        # With no records the UI immediately reports "Finished!", so make the
        # missing file visible instead of failing silently.
        import warnings
        warnings.warn(
            f"Data file not found: {path}; starting with no records.",
            RuntimeWarning,
            stacklevel=2,
        )
        return []
    # Explicit UTF-8 so non-ASCII text in the records decodes the same way
    # regardless of the platform's default encoding.
    with open(path, 'r', encoding="utf-8") as f:
        return json.load(f)

# Loaded once at import time; the handlers below (call_ai_processor,
# get_record, login_user, save_and_next) all read from this module-level list.
DATA = load_data()

# --- AI LOGIC ---
def call_ai_processor(index, full_text, gold_summary):
    """Ask the OpenAI model to rewrite the current record for its target label.

    Fills ``PROMPT_TEMPLATE`` with the record's full text, language, gold
    summary and ``ai_label``. It then requests a JSON-object response from
    "gpt-5-mini" and returns the value stored under the record's
    ``ai_label`` key.

    Args:
        index: Position of the current record in ``DATA``.
        full_text: Source full text displayed in the UI.
        gold_summary: Gold summary for the record (may be empty).

    Returns:
        The refined text, or an error string. The result is written straight
        into the correction textbox, so failures are returned as readable
        text instead of being raised.
    """
    # After the last record the UI holds index == len(DATA); report that
    # clearly instead of surfacing "list index out of range".
    if index is None or not 0 <= index < len(DATA):
        return "AI Error: no record is currently loaded."

    item = DATA[index]
    target_label = item.get('ai_label')  # e.g., "low_health_literacy"
    if not target_label:
        return "AI Error: this record has no 'ai_label'."

    # The JSON may not carry a language field; default to English.
    source_lang = item.get('language') or 'English'

    # Missing text fields become "" so str.replace never receives None.
    prompt = (PROMPT_TEMPLATE
              .replace("<<<FULL_TEXT>>>", full_text or "")
              .replace("<<<SOURCE_LANGUAGE>>>", str(source_lang))
              .replace("<<<GOLD_SUMMARY>>>", gold_summary or "")
              .replace("<<<TARGET_LABEL>>>", str(target_label)))

    try:
        response = client.chat.completions.create(
            model="gpt-5-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"},
        )
        content = json.loads(response.choices[0].message.content)
        # Presumably the response has one key per label (e.g.
        # low_health_literacy); only the current label's text is used.
        refined_text = content.get(target_label, "Error: Label not found in AI response.")
    except Exception as e:  # UI boundary: show any API/parsing failure in the textbox
        return f"AI Error: {str(e)}"

    # A non-string value (nested object/list) would otherwise be displayed as
    # a Python repr; render it as JSON instead.
    if not isinstance(refined_text, str):
        refined_text = json.dumps(refined_text, ensure_ascii=False)
    return refined_text

# --- DATA HELPERS ---
def get_user_save_path(username):
    """Return the per-user results file path inside SAVE_DIR.

    Only letters and digits from *username* are kept, so the name is safe to
    embed in a file name (e.g. "Shahidul K." -> "final_corrected_ShahidulK.json").
    """
    safe_chars = filter(lambda ch: ch.isalpha() or ch.isdigit(), username)
    return os.path.join(SAVE_DIR, "final_corrected_" + "".join(safe_chars) + ".json")

def load_user_results(username):
    """Return the list of results already saved for *username*.

    An annotator without a results file gets an empty list; login_user uses
    the length of this list as the index to resume from.
    """
    results_path = get_user_save_path(username)
    if not os.path.exists(results_path):
        return []
    with open(results_path, 'r') as fh:
        return json.load(fh)

def get_record(index):
    """Build the display fields for the record at *index* in DATA.

    Returns a 7-tuple ``(doc_id, annotator_summary, pretty_label, full_text,
    ai_text, index, gold_summary)``, or ``None`` when *index* is out of range.
    """
    if not (0 <= index < len(DATA)):
        return None

    record = DATA[index]
    label_key = record.get('ai_label', '')
    label_texts = record.get('diff_label_texts', {})
    annotator_lines = [
        f"{name}: {record.get(f'category_{key}')} (Rating: {record.get(f'rating_{key}')})"
        for name, key in (("Plaban", "plaban"), ("Mahi", "mahi"), ("Shama", "shama"))
    ]
    return (
        record.get('doc_id'),
        "\n".join(annotator_lines),
        label_key.replace("_", " ").title(),
        record.get('fulltext'),
        label_texts.get(label_key, "Text not found"),
        index,
        record.get('summary', ''),  # gold summary, later fed to the AI prompt
    )

def login_user(username):
    """Resume *username*'s session at the first record they have not saved.

    Returns nine values for the login outputs: the login-row update, the
    main-container update, the start index, then doc id, annotator ratings,
    label, full text, original AI text and gold summary.
    """
    if not username or not username.strip():
        # Keep the login row visible and the annotation panel hidden.
        return gr.update(visible=True), gr.update(visible=False), 0, None, "", "", "", "", ""

    start_index = len(load_user_results(username))
    if start_index >= len(DATA):
        return (gr.update(visible=False), gr.update(visible=True), start_index,
                "Finished!", "All caught up!", "No more data.", "No more data.", "", "")

    doc_id, anno_info, label, full_text, ai_text, _, gold_summary = get_record(start_index)
    return (gr.update(visible=False), gr.update(visible=True), start_index,
            doc_id, anno_info, label, full_text, ai_text, gold_summary)

def _finished_outputs(next_index):
    """Return the save_and_next outputs shown once every record is done."""
    return [None, "Finished!", "Finished!", "No more data.", "No more data.", next_index, "No more data.", ""]


def _write_results_atomically(path, results):
    """Write *results* as indented JSON to *path* via a temp file and rename."""
    import uuid
    directory = os.path.dirname(path)
    if directory:
        # The results directory may not exist yet on a fresh setup.
        os.makedirs(directory, exist_ok=True)
    tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(results, f, indent=4)
        # os.replace is atomic on POSIX (same filesystem), so a crash mid-write
        # can no longer truncate the user's existing results file.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def save_and_next(username, index, corrected_text, is_ok):
    """Persist the decision for record *index* and load the following record.

    Args:
        username: Annotator name; selects the per-user results file.
        index: Position of the record being saved in ``DATA``.
        corrected_text: Contents of the correction textbox.
        is_ok: True when the original AI text is approved. In that case
            ``corrected_text`` is ignored and the original text is stored.

    Returns:
        An 8-item list for the action outputs: doc id, annotator ratings,
        label, full text, original AI text, new index, gold summary, and ""
        to clear the correction textbox.
    """
    if index >= len(DATA):
        # Every record is already saved but the buttons remain clickable;
        # previously this raised IndexError on DATA[index].
        return _finished_outputs(index)

    user_results = load_user_results(username)
    current_item = DATA[index]
    label = current_item['ai_label']
    original_text = current_item.get('diff_label_texts', {}).get(label)

    user_results.append({
        "doc_id": current_item['doc_id'],
        "ai_label": label,
        "status": "Approved" if is_ok else "Manually Corrected/AI Refined",
        # Approval always stores the original text, whatever the textbox holds.
        "final_text": original_text if is_ok else corrected_text,
        "original_ai_text": original_text,
    })
    _write_results_atomically(get_user_save_path(username), user_results)

    next_index = index + 1
    if next_index < len(DATA):
        return list(get_record(next_index)) + [""]
    return _finished_outputs(next_index)

# --- GRADIO UI ---
# The on-screen layout follows the order in which components are created
# inside these `with` blocks, so statement order here is significant.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📝 AI Label Correction Interface (v2 with GPT-Refinement)")
    
    # Per-session state: index of the record on screen, the logged-in
    # username, and the gold summary passed to call_ai_processor.
    current_idx = gr.State(0)
    user_session = gr.State("")
    gold_summary_hidden = gr.State("") # To hold the summary for the AI prompt

    # Login panel; login_user hides it once a non-empty username is entered.
    with gr.Row() as login_row:
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="Enter Username to Resume", placeholder="e.g., Shahidul")
            btn_login = gr.Button("Start Annotation", variant="primary")

    # Annotation panel, hidden until login_user makes it visible.
    with gr.Column(visible=False) as main_container:
        with gr.Row():
            with gr.Column(scale=1):
                doc_id_display = gr.Textbox(label="Document ID", interactive=False)
                ai_label_display = gr.Label(label="Target AI Label")
                annotator_stats = gr.Textbox(label="Human Annotator Ratings", lines=4, interactive=False)
            
            with gr.Column(scale=2):
                full_text_display = gr.Textbox(label="Source Full Text", lines=10, interactive=False)

        with gr.Row():
            with gr.Column():
                ai_generated_text = gr.Textbox(label="Original AI Text", lines=6, interactive=False)
            with gr.Column():
                # Editable: filled by the AI button or typed by the annotator;
                # its content is what "Save Current Correction" stores.
                manual_correction = gr.Textbox(label="AI Refinement / Manual Correction", placeholder="AI generated text will appear here...", lines=6)
                btn_ai_check = gr.Button("✨ Check & Refine through AI", variant="secondary")

        with gr.Row():
            btn_ok = gr.Button("✅ Original Text is OK", variant="primary")
            btn_fix = gr.Button("💾 Save Current Correction/AI Text", variant="stop")

    # --- LOGIC ---
    # Output order must match the 9-tuple returned by login_user. The chained
    # .then copies the typed name into user_session; it also runs when
    # login_user rejects an empty name (the panel simply stays hidden).
    btn_login.click(
        fn=login_user,
        inputs=[user_input],
        outputs=[login_row, main_container, current_idx, doc_id_display, annotator_stats, ai_label_display, full_text_display, ai_generated_text, gold_summary_hidden]
    ).then(fn=lambda username: username, inputs=[user_input], outputs=[user_session])

    # AI Regeneration Logic: the refined text (or an error string) replaces
    # the correction textbox content; nothing is saved until a button below.
    btn_ai_check.click(
        fn=call_ai_processor,
        inputs=[current_idx, full_text_display, gold_summary_hidden],
        outputs=[manual_correction]
    )

    # Shared by both save buttons; output order must match save_and_next's
    # 8-item return list (the last item clears manual_correction).
    action_inputs = [user_session, current_idx, manual_correction]
    action_outputs = [doc_id_display, annotator_stats, ai_label_display, full_text_display, ai_generated_text, current_idx, gold_summary_hidden, manual_correction]

    # Approve: saves the original AI text; the textbox content is ignored.
    btn_ok.click(
        fn=lambda user, idx, txt: save_and_next(user, idx, txt, True),
        inputs=action_inputs,
        outputs=action_outputs
    )

    # Save correction: stores whatever is currently in manual_correction.
    btn_fix.click(
        fn=lambda user, idx, txt: save_and_next(user, idx, txt, False),
        inputs=action_inputs,
        outputs=action_outputs
    )

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local
    # server. Anyone with the link can use this UI, including the AI button,
    # which makes OpenAI calls with the configured key.
    launch_options = {"share": True}
    demo.launch(**launch_options)