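"""Gradio app combining multilingual document RAG (LangChain + LangGraph +
Google Gemini) with an OCR-driven form-filling pipeline (Tesseract + OpenCV)."""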
import base64
import getpass
import json
import os
import re

import cv2
import gradio as gr
import pytesseract
from pdf2image import convert_from_path
from typing_extensions import TypedDict

from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_unstructured import UnstructuredLoader
from langgraph.graph import START, END, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
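# System dependencies (not pip-installable): pytesseract needs the Tesseract
# OCR engine and pdf2image needs the poppler utilities available on PATH.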
# ---------- SETUP ----------
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

llm = init_chat_model("gemini-1.5-flash", model_provider="google_genai")

embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-large-instruct",
    model_kwargs={"device": "cpu", "trust_remote_code": True},
)

vector_store = InMemoryVectorStore(embeddings)
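# The in-memory store is rebuilt on every restart. A persistent alternative
# (a sketch, assuming the langchain-chroma package is installed) would be:
#   from langchain_chroma import Chroma
#   vector_store = Chroma(embedding_function=embeddings, persist_directory="./chroma_db")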
# ---------- RETRIEVAL TOOL ----------
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    retrieved_docs = vector_store.similarity_search(query, k=3)
    # The serialized string is what the model sees; the raw documents travel
    # alongside it as the tool message's artifact.
    serialized = "\n\n".join(
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs
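# Sanity-check sketch (hypothetical query; run after ingesting a document):
#   print(retrieve.invoke({"query": "What does the document cover?"}))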
# ---------- GRAPH FUNCTIONS FOR RAG ----------
def query_or_respond(state: MessagesState):
    """Generate tool call for retrieval or respond."""
    llm_with_tools = llm.bind_tools([retrieve])
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}

tools = ToolNode([retrieve])
def generate(state: MessagesState):
    """Generate an answer from the retrieved context."""
    # Collect the most recent, consecutive tool messages (the retrieval results).
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]
    docs_content = "\n\n".join(msg.content for msg in tool_messages)
    print("retrieved docs:", docs_content)  # debug output
    system_message_content = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you don't know. "
        "Use three sentences maximum and keep the answer concise."
        "\n\n"
        f"{docs_content}"
    )
    # Keep only human/system turns and AI turns that are not tool calls.
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages
    response = llm.invoke(prompt)
    return {"messages": [response]}
# ---------- BUILD RAG GRAPH ----------
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)
graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)
rag_graph = graph_builder.compile()
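# Standalone usage sketch (assumes a document has already been ingested):
#   state = {"messages": [HumanMessage(content="Summarise the document")]}
#   for step in rag_graph.stream(state, stream_mode="values"):
#       step["messages"][-1].pretty_print()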
# ---------- FORM FILLING AGENTS ----------
class FormState(TypedDict):
    template_path: str
    source_path: str
    schema: dict
    filled_data: dict
    filled_image: str
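# FormState flows through three nodes: "analyze" detects field labels and their
# input boxes on the template, "extract" pulls matching values out of the source
# document, and "fill" renders those values onto the template image.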
def find_bbox(file_path, prim_schema, alignment="down"):
    """Locate, for each schema key, the pixel bbox of the input region next to
    its label on the form image. Returns {key: [x1, y1, x2, y2]}."""
    keys = list(prim_schema.keys())
    img = cv2.imread(file_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    data = pytesseract.image_to_data(img_rgb, output_type=pytesseract.Output.DICT)

    # Keep every confident, non-empty OCR word with its geometry.
    words = []
    for i in range(len(data['text'])):
        if int(data['conf'][i]) > -1 and data['text'][i].strip():
            words.append({
                'text': data['text'][i],
                'left': data['left'][i],
                'top': data['top'][i],
                'width': data['width'][i],
                'height': data['height'][i],
                'conf': int(data['conf'][i]),
            })
    words.sort(key=lambda w: (w['top'], w['left']))

    # Group words into lines: a vertical jump of more than 20 px starts a new line.
    lines = []
    current_line = []
    current_top = None if not words else words[0]['top']
    for word in words:
        if current_line and abs(word['top'] - current_top) > 20:
            lines.append(current_line)
            current_line = []
        current_line.append(word)
        current_top = word['top']
    if current_line:
        lines.append(current_line)

    # Merge each line into a single box with normalized lookup text.
    boxes = []
    for line in lines:
        if line:
            full_text = ' '.join(w['text'] for w in line)
            left = min(w['left'] for w in line)
            top = min(w['top'] for w in line)
            right = max(w['left'] + w['width'] for w in line)
            bottom = max(w['top'] + w['height'] for w in line)
            boxes.append({
                'text': full_text,
                'clean': full_text.lower().replace(" ", "").replace(":", ""),
                'left': left,
                'top': top,
                'width': right - left,
                'height': bottom - top,
            })
    boxes.sort(key=lambda b: (b['top'], b['left']))

    schema = {}
    height, width = img.shape[:2]
    threshold = 40
    for idx, box in enumerate(boxes):
        if box['clean'] in keys and box['clean'] not in schema:
            key = box['clean']
            input_bbox_pixel = None
            if alignment == "right":
                # Look for the next box to the right on (roughly) the same line.
                for next_idx in range(idx + 1, len(boxes)):
                    next_box = boxes[next_idx]
                    if (next_box['top'] >= box['top']
                            and next_box['top'] + next_box['height'] <= box['top'] + box['height'] + 10
                            and box['left'] + box['width'] < next_box['left'] < box['left'] + box['width'] + 300):
                        input_bbox_pixel = [next_box['left'], next_box['top'],
                                            next_box['left'] + next_box['width'],
                                            next_box['top'] + next_box['height']]
                        break
            else:  # Default "down": look for the next box below the label.
                for next_idx in range(idx + 1, len(boxes)):
                    next_box = boxes[next_idx]
                    if (next_box['top'] > box['top'] + box['height']
                            and abs(next_box['left'] - box['left']) < 50
                            and next_box['top'] - box['top'] < 100):
                        input_bbox_pixel = [next_box['left'], next_box['top'],
                                            next_box['left'] + next_box['width'],
                                            next_box['top'] + next_box['height']]
                        break
            if input_bbox_pixel is None:
                # No neighbouring box found: fall back to a synthetic box offset
                # from the label.
                if alignment == "right":
                    input_x = box['left'] + box['width'] + threshold
                    input_y = box['top']
                    input_w = 200  # typical input-field width
                    input_h = box['height']
                else:  # "down"
                    input_x = box['left']
                    input_y = box['top'] + box['height'] + threshold
                    input_w = box['width']
                    input_h = 20
                input_bbox_pixel = [input_x, input_y, input_x + input_w, input_y + input_h]
            schema[key] = input_bbox_pixel
    return schema
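# Illustrative result (coordinates invented): for a template with a "name"
# label, find_bbox might return {"name": [120, 340, 320, 365]}, the pixel
# [x1, y1, x2, y2] of the region where the value should be written.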
def convert_template_file(file_path):
    """Convert the first page of a PDF to a PNG image."""
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".pdf":
        images = convert_from_path(file_path, dpi=300)
        out_path = "template_converted.png"
        images[0].save(out_path, "PNG")
        return out_path
    else:
        raise ValueError("Unsupported template file format")
def analyze_template(file_path: str) -> dict:
    """Ask the LLM for the template's field keys, then locate each field's
    input box with OCR. Expects an image path (PDFs are converted upstream)."""
    with open(file_path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")
    mime_type = "image/png" if file_path.lower().endswith(".png") else "image/jpeg"
    message = {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "Analyse the following form and return just the JSON containing keys and values present in the image. "
                        "If the corresponding value is not present, keep it as None. "
                        "Keep in mind that the keys cannot contain any spaces and should be lowercase. "
                        "The output should be JSON loadable.",
            },
            {
                "type": "image",
                "source_type": "base64",
                "data": image_data,
                "mime_type": mime_type,
            },
        ],
    }
    response = llm.invoke([message])
    out = response.text()
    # The model often wraps JSON in a ```json ... ``` fence; extract the first
    # {...} block so json.loads receives clean JSON.
    match = re.search(r"\{.*\}", out, re.DOTALL)
    schema = json.loads(match.group(0)) if match else {}
    return find_bbox(file_path, schema)
def extract_values(file_path: str, schema: dict) -> dict:
    """Extract values for the schema's keys from a text, image, or PDF source."""
    schema_keys = list(schema.keys())
    if file_path.endswith(".txt"):
        with open(file_path, "r") as f:
            text = f.read()
        instr = (
            f"Extract values from the source text for the following fields: {schema_keys}. "
            f"Return a JSON with keys and extracted values.\nsource text: {text}"
        )
        message = {
            "role": "user",
            "content": [{"type": "text", "text": instr}],
        }
    else:
        if file_path.endswith(".pdf"):
            # Convert the first page of the PDF source to an image.
            images = convert_from_path(file_path, dpi=300)
            out_path = "source_converted.png"
            images[0].save(out_path, "PNG")
            file_path = out_path
        with open(file_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode("utf-8")
        mime_type = "image/png" if file_path.lower().endswith(".png") else "image/jpeg"
        text = (
            f"Extract values from the image for the following fields: {schema_keys}. "
            "Return a JSON with keys and extracted values."
        )
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {
                    "type": "image",
                    "source_type": "base64",
                    "data": image_data,
                    "mime_type": mime_type,
                },
            ],
        }
    response = llm.invoke([message])
    out = response.text()
    match = re.search(r"\{.*\}", out, re.DOTALL)  # strip any ```json fences
    filled = json.loads(match.group(0)) if match else {}
    # Keep only the keys the template actually asks for.
    return {key: filled[key] for key in schema if key in filled}
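# Illustrative result (made-up values): extract_values("id_card.jpg", schema)
# might return {"firstname": "Anita", "dateofbirth": "12/03/1991"}.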
def fill_template(state: FormState):
    """Draw the extracted values onto the template image and save a copy."""
    template_path = state["template_path"]
    position = state["schema"]
    filled_data = state["filled_data"]
    img = cv2.imread(template_path)
    for key, bbox in position.items():
        value = filled_data.get(key)
        if value is None:
            continue
        x1, y1, x2, y2 = bbox
        # Position text inside the box, slightly padded.
        cv2.putText(img, str(value), (x1 + 5, y2 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8,
                    (0, 0, 0), 2)
    filled_path = template_path.replace(".jpg", "_filled.jpg").replace(".png", "_filled.png")
    cv2.imwrite(filled_path, img)
    return filled_path
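# Note: cv2.putText only renders OpenCV's built-in Hershey (ASCII) glyphs, so
# non-Latin values (e.g. Malayalam) will not draw correctly; PIL's
# ImageDraw.text with a Unicode TrueType font would be one alternative.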
def analyze_node(state: FormState):
    template_path = state["template_path"]
    # Convert PDF templates to an image up front so fill_template later draws
    # on the same raster that was analyzed.
    if template_path.lower().endswith(".pdf"):
        template_path = convert_template_file(template_path)
    schema = analyze_template(template_path)
    return {"schema": schema, "template_path": template_path}

def extract_node(state: FormState):
    filled = extract_values(state["source_path"], state["schema"])
    return {"filled_data": filled}

def fill_node(state: FormState):
    filled_image = fill_template(state)
    return {"filled_image": filled_image}
# ---------- BUILD FORM FILLING GRAPH ----------
form_graph_builder = StateGraph(FormState)
form_graph_builder.add_node("analyze", analyze_node)
form_graph_builder.add_node("extract", extract_node)
form_graph_builder.add_node("fill", fill_node)
form_graph_builder.add_edge(START, "analyze")
form_graph_builder.add_edge("analyze", "extract")
form_graph_builder.add_edge("extract", "fill")
form_graph_builder.add_edge("fill", END)
form_graph = form_graph_builder.compile()
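# Direct invocation sketch (hypothetical file names):
#   result = form_graph.invoke({
#       "template_path": "form.png", "source_path": "id_card.jpg",
#       "schema": {}, "filled_data": {}, "filled_image": "",
#   })
# result["filled_data"] holds the extracted key/value dict and
# result["filled_image"] the path to the rendered form.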
# ---------- GRADIO APP FUNCTIONS ----------
def process_doc(file):
    loader = UnstructuredLoader(
        file_path=file.name,
        extract_images_in_pdf=True,
        languages=['ml', 'en'],
    )
    docs = loader.load()
    vector_store.add_documents(documents=docs)
    return "✅ Document processed successfully! You can now ask questions."
def ask_question(query, history):
    state = {"messages": [HumanMessage(content=query)]}
    response_text = ""
    for step in rag_graph.stream(state, stream_mode="values"):
        response_text = step["messages"][-1].content
    history.append((query, response_text))
    return history, ""
def process_form_filling(source, template):
    if not source or not template:
        return {"error": "Please upload both source and template files."}, None
    state = {
        "template_path": template.name,
        "source_path": source.name,
        "schema": {},
        "filled_data": {},
        "filled_image": "",
    }
    result = form_graph.invoke(state)
    return result["filled_data"], result["filled_image"]
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"],
    font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"],
).set(
    body_background_fill="*neutral_50",
    body_background_fill_dark="*neutral_900",
    block_background_fill="*neutral_100",
    block_background_fill_dark="*neutral_800",
    block_border_width="1px",
    block_radius="8px",
    block_shadow="0 1px 3px rgba(0,0,0,0.1)",
    block_shadow_dark="0 1px 3px rgba(255,255,255,0.1)",
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
    button_primary_text_color="white",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_hover="*neutral_300",
    button_secondary_text_color="*neutral_800",
    input_background_fill="*neutral_50",
    input_background_fill_dark="*neutral_800",
    input_border_color="*neutral_200",
    input_border_color_dark="*neutral_700",
    panel_background_fill="*neutral_50",
    panel_background_fill_dark="*neutral_900",
    slider_color="*primary_500",
)
with gr.Blocks(theme=theme, css=".gradio-container {max-width: 1200px !important; margin: auto;}") as demo:
    gr.Markdown(
        """
        # 📄 Multi-lingual Doc RAG and Form Filling System
        """,
        elem_classes="text-center"
    )
    with gr.Tabs():
        with gr.Tab("📄 Document RAG"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Upload Document")
                    upload_btn = gr.File(label="Select Document", file_types=[".pdf", ".txt", ".docx"], interactive=True)
                    process_status = gr.Textbox(label="Processing Status", placeholder="Upload a document to start...", interactive=False)
                    upload_btn.upload(process_doc, upload_btn, process_status)
                with gr.Column(scale=2):
                    gr.Markdown("### Chat with Document")
                    chatbot = gr.Chatbot(height=400, placeholder="Ask questions about your document here...")
                    query = gr.Textbox(label="Your Question", placeholder="Type your question and press Enter...")
                    query.submit(ask_question, [query, chatbot], [chatbot, query])
        with gr.Tab("📝 Form Filling"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Upload Files")
                    source_upload = gr.File(label="Source File (Information Source)", file_types=[".jpg", ".png", ".txt", ".pdf"], interactive=True)
                    template_upload = gr.File(label="Template File (Form to Fill)", file_types=[".jpg", ".png"], interactive=True)
                    fill_btn = gr.Button("Process and Fill Form", variant="primary")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Extracted Data")
                    output_json = gr.JSON(label="Filled Form Data (JSON)")
                with gr.Column():
                    gr.Markdown("### Filled Form Preview")
                    output_image = gr.Image(label="Filled Form Image", interactive=False)
            fill_btn.click(process_form_filling, [source_upload, template_upload], [output_json, output_image])
    gr.Markdown(
        """
        ---
        *Note: form filling currently expects the source document to contain a value for every field detected in the template, stored under the matching key.*
        """,
        elem_classes="text-center"
    )
demo.launch(server_name="0.0.0.0", server_port=7860)