Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import os | |
| from PIL import Image | |
| from database_operations import Neo4jDatabase | |
| from graph_visualization import visualize_graph | |
| from utils import extract_label_prefix, strip_keys, format_json, validate_json | |
| from models.gemini_image_to_json import fetch_gemini_response | |
| from models.openai_image_to_json import openaiprocess_image_to_json | |
| from any_to_image import pdf_to_images, process_image | |
# Initialize Neo4j database.
# Connection settings are read from the environment so credentials need not be
# committed to source; the defaults reproduce the previous hard-coded values
# for local development.
db = Neo4jDatabase(
    os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "password123"),
)
def dump_to_neo4j_with_confirmation(json_content, file_path, history, previous_states):
    """Write the edited graph JSON to Neo4j, or ask for confirmation first.

    The label prefix is derived from the uploaded file path. If a graph with
    that prefix already exists, nothing is written; the prefix is returned so
    the UI can ask the user to confirm an overwrite (see ``confirm_overwrite``).

    Args:
        json_content: JSON text from the editor. After ``strip_keys`` it must
            be an object with ``nodes`` and ``edges`` keys.
        file_path: Path of the uploaded image; used to derive the label prefix.
        history: Newline-separated action log shown in the UI.
        previous_states: Dict of label prefix -> list of earlier graph
            snapshots used by ``revert_last_action``. Mutated in place.

    Returns:
        ``(message, history, previous_states, pending_label_prefix)`` where the
        last item is the label prefix awaiting overwrite confirmation, or
        ``None`` when no confirmation is needed.
    """
    if not file_path:
        return "No image uploaded or invalid file", history, previous_states, None
    try:
        json_data = json.loads(json_content)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a None/non-string editor value.
        return "Invalid JSON data. Please check your input.", history, previous_states, None
    json_data = strip_keys(json_data)
    # Validate the shape now, so a malformed payload never reaches the
    # overwrite prompt (where the existing graph would be deleted first).
    if not isinstance(json_data, dict) or 'nodes' not in json_data or 'edges' not in json_data:
        return "Invalid graph structure: expected 'nodes' and 'edges' keys.", history, previous_states, None

    label_prefix = extract_label_prefix(file_path)
    if db.check_existing_graph(label_prefix):
        return (
            f"A graph with label prefix '{label_prefix}' already exists in the database. Do you want to overwrite it?",
            history,
            previous_states,
            label_prefix,
        )

    db.dump_to_neo4j(json_data['nodes'], json_data['edges'], label_prefix)
    result = f"Data successfully dumped into the database with label prefix '{label_prefix}'."
    new_history = f"{history}\n[NEW ENTRY] {result}" if history else f"[NEW ENTRY] {result}"
    # An empty snapshot list marks a freshly created graph; revert_last_action
    # undoes such an entry by deleting the graph.
    previous_states[label_prefix] = []
    return result, new_history, previous_states, None
def confirm_overwrite(confirmation, gradio_state, json_content, file_path, history, previous_states):
    """Overwrite an existing graph after the user typed 'yes'.

    The edited JSON is parsed and validated BEFORE anything in the database is
    touched, so invalid input can no longer leave the existing graph deleted.
    The current graph is snapshotted into ``previous_states`` (the three most
    recent snapshots are kept) so ``revert_last_action`` can restore it.

    Args:
        confirmation: Text typed by the user. Only 'yes' (case-insensitive,
            surrounding whitespace ignored) proceeds.
        gradio_state: Label prefix the user was prompted about, if the UI
            passes one. When falsy, the prefix is derived from ``file_path``.
        json_content: JSON text from the editor; must contain ``nodes`` and
            ``edges`` after ``strip_keys``.
        file_path: Path of the uploaded image.
        history: Newline-separated action log.
        previous_states: Dict of label prefix -> snapshot list. Mutated in place.

    Returns:
        ``(message, history, previous_states, "")``. The trailing ``""``
        clears the confirmation textbox.
    """
    if (confirmation or "").strip().lower() != 'yes':
        return "Operation cancelled. The existing graph was not overwritten.", history, previous_states, ""
    if not gradio_state and not file_path:
        return "No image uploaded or invalid file", history, previous_states, ""

    try:
        json_data = json.loads(json_content)
    except (json.JSONDecodeError, TypeError):
        return "Invalid JSON data. Please check your input.", history, previous_states, ""
    json_data = strip_keys(json_data)
    if not isinstance(json_data, dict) or 'nodes' not in json_data or 'edges' not in json_data:
        return "Invalid graph structure: expected 'nodes' and 'edges' keys.", history, previous_states, ""

    label_prefix = gradio_state or extract_label_prefix(file_path)

    # Snapshot the graph about to be replaced; keep only the three newest.
    snapshots = previous_states.setdefault(label_prefix, [])
    snapshots.append(db.get_graph_data(label_prefix))
    del snapshots[:-3]

    db.delete_graph(label_prefix)
    db.dump_to_neo4j(json_data['nodes'], json_data['edges'], label_prefix)
    result = f"Data successfully overwritten in the database with label prefix '{label_prefix}'."
    new_history = f"{history}\n[OVERWRITE] {result}" if history else f"[OVERWRITE] {result}"
    return result, new_history, previous_states, ""
def revert_last_action(history, previous_states):
    """Undo the action recorded on the last line of ``history``.

    The label prefix is read from the single-quoted part of that line. If
    ``previous_states`` holds snapshots for the prefix, the newest one is
    restored (undoing an overwrite). If it holds an empty list, the graph is
    deleted (undoing a new entry). Otherwise nothing can be reverted.

    Snapshots are whatever ``db.get_graph_data`` returned; they are indexed
    here with ``'nodes'`` and ``'edges'`` keys.

    Returns:
        ``(message, history, previous_states)``. ``previous_states`` is
        mutated in place.
    """
    if not history:
        return "No actions to revert.", history, previous_states
    last_action = history.split('\n')[-1]

    # Take the text between the first and last quote, so label prefixes that
    # themselves contain an apostrophe are extracted intact.
    start, end = last_action.find("'"), last_action.rfind("'")
    if start == -1 or end == start:
        return "Unable to revert the last action.", history, previous_states
    label_prefix = last_action[start + 1:end]

    snapshots = previous_states.get(label_prefix)
    if snapshots is None:
        return "Unable to revert the last action.", history, previous_states

    if snapshots:
        snapshot = snapshots[-1]
        db.delete_graph(label_prefix)
        db.dump_to_neo4j(snapshot['nodes'], snapshot['edges'], label_prefix)
        # Pop only after the restore succeeded, so a failed dump keeps the snapshot.
        snapshots.pop()
        new_history = history + f"\n[REVERT] Reverted overwrite of graph with label prefix '{label_prefix}'"
    else:
        db.delete_graph(label_prefix)
        del previous_states[label_prefix]
        new_history = history + f"\n[REVERT] Deleted newly added graph with label prefix '{label_prefix}'"
    return f"Reverted last action: {last_action}", new_history, previous_states
| def update_graph_from_edited_json(json_content, physics_enabled): | |
| try: | |
| json_data = json.loads(json_content) | |
| json_data = strip_keys(json_data) | |
| validate_json(json_data) | |
| return visualize_graph(json_data, physics_enabled), "" | |
| except json.JSONDecodeError as e: | |
| return None, f"Invalid JSON format: {str(e)}" | |
| except ValueError as e: | |
| return None, f"Invalid graph structure: {str(e)}" | |
| except Exception as e: | |
| return None, f"An unexpected error occurred: {str(e)}" | |
def fetch_kg(image_file_path, model_choice_state):
    """Ask the selected model to turn the uploaded image into graph JSON.

    Args:
        image_file_path: Path of the image to send to the model.
        model_choice_state: ``'Gemini'`` or ``'OpenAI'``.

    Returns:
        ``(formatted_json, "")`` on success, or ``("", error_message)`` when
        no image was given, the model choice is unknown, the image cannot be
        opened, or the model's reply is not valid JSON. Exceptions raised by
        the model call itself propagate to the caller.
    """
    if not image_file_path:
        return "", "No image uploaded or invalid file"

    # Resolve the model before opening the image so an unknown choice fails
    # fast (the original raised UnboundLocalError here).
    if model_choice_state == 'Gemini':
        print('model choice is gemini')
        model_fn = fetch_gemini_response
    elif model_choice_state == 'OpenAI':
        print('model choice is openai')
        model_fn = openaiprocess_image_to_json
    else:
        return "", f"Unknown model choice: {model_choice_state}"

    try:
        mind_map_image = Image.open(image_file_path)
    except OSError as e:
        return "", f"Could not open image: {e}"
    # Close the file handle once the model has consumed the image.
    with mind_map_image:
        kg_json_text = model_fn(mind_map_image)

    try:
        json_data = json.loads(kg_json_text)
    except (json.JSONDecodeError, TypeError) as e:
        return "", f"Model returned invalid JSON: {e}"
    return format_json(json_data), ""
def input_file_handler(file_path):
    """Turn an uploaded file into an image path for the preview component.

    Delegates to ``process_image`` and returns its ``(image_path, error)``
    pair as a tuple. When nothing was uploaded, returns
    ``("", "No image uploaded or invalid file")``.
    """
    if not file_path:
        return "", "No image uploaded or invalid file"
    processed_path, message = process_image(file_path)
    return processed_path, message
# Gradio interface


def _rotate_image_file(image_path, angle):
    """Rotate the image stored at ``image_path`` in place and return the path.

    ``angle`` follows Pillow's ``Image.rotate`` convention: positive values
    rotate counter-clockwise, negative values clockwise. ``expand=True`` grows
    the canvas so the rotated image is not cropped. A falsy path is returned
    unchanged.
    """
    if image_path:
        # Close the source file before overwriting it, so the handle is not
        # leaked (and the save works on platforms that lock open files).
        with Image.open(image_path) as image:
            rotated = image.rotate(angle, expand=True)
        rotated.save(image_path)
    return image_path


with gr.Blocks() as demo:
    gr.Markdown("## Image to Knowledge Graph Transformation")
    with gr.Row():
        file_input = gr.File(label="Upload File", file_count="single",
                             type="filepath",
                             file_types=[".pdf", ".png", ".jpeg", ".jpg", ".heic"])
        image_file = gr.Image(label="Input Image", type="filepath", visible=False)
        json_editor = gr.Textbox(label="Edit JSON", lines=15, placeholder="JSON data will appear here after image upload")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                CCW_rotate_button = gr.Button('Rotate Image Counter-Clockwise')
                CW_rotate_button = gr.Button('Rotate Image Clockwise')
        with gr.Column():
            model_call = gr.Button('Transform Image into KG representation', scale=2)
            with gr.Row():
                physics_button = gr.Checkbox(value=True, label="Enable Graph Physics")
                model_choice = gr.Radio(label="Select Model", choices=["OpenAI", "Gemini"], value="Gemini", interactive=True)
    graph_output = gr.HTML(label="Graph Output")
    error_output = gr.Textbox(label="Error Messages", interactive=False)
    update_button = gr.Button("Update Graph")
    dump_button = gr.Button("Dump to Neo4j")
    revert_button = gr.Button("Revert Last Action")
    history_block = gr.Textbox(label="History", placeholder="Graphs pushed to the Database", interactive=False, lines=5, max_lines=50)
    history_state = gr.State("")
    previous_states = gr.State({})
    # Label prefix awaiting overwrite confirmation (None when nothing is
    # pending). It must be ONE shared State: calling gr.State() inline in each
    # listener creates a fresh, unconnected component, which loses the value.
    pending_label_state = gr.State(None)
    confirmation_output = gr.Textbox(label="Confirmation Message", visible=False, interactive=False)
    confirmation_input = gr.Textbox(label="Type 'yes' to confirm overwrite", visible=False, interactive=True)
    confirm_button = gr.Button("Confirm Overwrite", visible=False)
    #-----------------------------------------
    # Added 2 examples for this deployment only
    # NOTE(review): `update_image_visibility` below is not defined in this file.
    # examples_list = ["image_examples/image1.png", "image_examples/image2.png"]
    # # same full chain of events as the file.upload() below
    # def process_input(file):
    #     # First, call input_file_handler
    #     processed_file, error = input_file_handler(file)
    #     # Then, update image visibility
    #     visible_image, hidden_file = update_image_visibility(processed_file)
    #     return processed_file, error, visible_image, hidden_file
    # example_component = gr.Examples(examples_list, inputs=file_input, fn=process_input, outputs=[image_file, error_output, image_file, file_input])
    #-------------------------------------------

    # Upload: normalise the file, then swap the uploader for the image preview.
    file_input.upload(
        fn=input_file_handler,
        inputs=[file_input],
        outputs=[image_file, error_output]
    ).then(
        lambda image_path: (
            gr.Image(value=image_path, visible=True),
            gr.File(visible=False)
        ),
        inputs=[image_file],
        outputs=[image_file, file_input]
    )

    # Clearing the preview brings the uploader back.
    image_file.clear(
        lambda: (
            gr.File(visible=True),
            gr.Image(visible=False)
        ),
        inputs=None,
        outputs=[file_input, image_file]
    )

    # Pillow: negative angle = clockwise, positive = counter-clockwise.
    CW_rotate_button.click(
        fn=lambda image_path: _rotate_image_file(image_path, -90),
        inputs=[image_file],
        outputs=[image_file]
    )
    CCW_rotate_button.click(
        fn=lambda image_path: _rotate_image_file(image_path, 90),
        inputs=[image_file],
        outputs=[image_file]
    )

    # Dump: write to Neo4j, or ask for confirmation when the graph exists.
    dump_button.click(
        dump_to_neo4j_with_confirmation,
        inputs=[json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, pending_label_state]
    ).then(
        lambda message, label_prefix: (
            gr.Textbox(value=message, visible=True),
            # Only show the "yes" prompt when an overwrite is actually pending.
            gr.Textbox(visible=bool(label_prefix)),
            gr.Button(visible=bool(label_prefix)),
        ),
        inputs=[confirmation_output, pending_label_state],
        outputs=[confirmation_output, confirmation_input, confirm_button]
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block]
    )

    # Confirm: overwrite the pending graph, then hide the prompt again.
    gr.on(
        triggers=[confirm_button.click, confirmation_input.submit],
        fn=confirm_overwrite,
        inputs=[confirmation_input, pending_label_state, json_editor, image_file, history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states, confirmation_input]
    ).then(
        lambda message: (
            gr.Textbox(value=message, visible=True),
            gr.Textbox(value='', visible=False),
            gr.Button(visible=False),
            None,  # nothing is pending any more
        ),
        inputs=[confirmation_output],
        outputs=[confirmation_output, confirmation_input, confirm_button, pending_label_state]
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block]
    )

    revert_button.click(
        revert_last_action,
        inputs=[history_state, previous_states],
        outputs=[confirmation_output, history_state, previous_states]
    ).then(
        lambda message: gr.Textbox(value=message, visible=True),
        inputs=[confirmation_output],
        outputs=[confirmation_output]
    ).then(
        lambda history: history,
        inputs=[history_state],
        outputs=[history_block]
    )

    update_button.click(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output]
    )
    physics_button.change(
        update_graph_from_edited_json,
        inputs=[json_editor, physics_button],
        outputs=[graph_output, error_output]
    )
    model_call.click(
        fn=fetch_kg,
        inputs=[image_file, model_choice],
        outputs=[json_editor, error_output]
    )

if __name__ == "__main__":
    demo.launch()