Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import time | |
| import os | |
| import shutil | |
| import pymupdf | |
| import json | |
| st.set_page_config( | |
| page_title="MGVG Grounding Demo", | |
| layout="wide", | |
| initial_sidebar_state="expanded", | |
| page_icon="logo.png" | |
| ) | |
| # --- Simple Authentication --- | |
| import streamlit as st | |
| import time | |
| # Define your valid credentials | |
| VALID_USERS = { | |
| "iitb": "iitb123", | |
| "badri": "badri123" | |
| } | |
| def login(): | |
| # Set a professional background for the whole app | |
| st.markdown( | |
| ''' | |
| <style> | |
| body, .stApp { | |
| background: linear-gradient(120deg, #e0eafc 0%, #cfdef3 100%) !important; | |
| } | |
| .login-box { | |
| background: #fff; | |
| padding: 2.5em 2em 2em 2em; | |
| border-radius: 16px; | |
| box-shadow: 0 4px 24px rgba(80, 120, 200, 0.12); | |
| min-width: 320px; | |
| max-width: 90vw; | |
| margin: auto; | |
| } | |
| </style> | |
| ''', unsafe_allow_html=True | |
| ) | |
| # Center the login box using columns | |
| col1, col2, col3 = st.columns([1,2,1]) | |
| with col2: | |
| # st.markdown('<div class="login-box">', unsafe_allow_html=True) | |
| # image at center | |
| st.image("logo.png", width=800, use_container_width=False) | |
| st.markdown('<h2 style="text-align:center; color:#2b6cb0; margin-bottom:1.5em;">🔒 Please log in to access the app</h2>', unsafe_allow_html=True) | |
| username = st.text_input("Username", key="login_username") | |
| password = st.text_input("Password", type="password", key="login_password") | |
| login_btn = st.button("Login") | |
| if login_btn: | |
| if username in VALID_USERS and VALID_USERS[username] == password: | |
| st.session_state["authenticated"] = True | |
| st.success("Login successful!") | |
| st.session_state["show_continue"] = True | |
| else: | |
| st.error("Invalid username or password") | |
| if st.session_state.get("show_continue", False): | |
| if st.button("Continue to App"): | |
| st.session_state["show_continue"] = False | |
| st.experimental_rerun() if hasattr(st, "experimental_rerun") else None | |
| st.markdown('</div>', unsafe_allow_html=True) | |
| if "authenticated" not in st.session_state: | |
| st.session_state["authenticated"] = False | |
| if not st.session_state["authenticated"]: | |
| login() | |
| st.stop() | |
| # --- End Authentication --- | |
| # st.image("logo.png", width=250) | |
| from PIL import Image, ImageDraw | |
| import io | |
| # from st_audiorec import st_audiorec | |
| from surya.layout import LayoutPredictor | |
| from doctr.models import ocr_predictor | |
| from transformers import pipeline | |
| def get_layout_predictor(): | |
| return LayoutPredictor() | |
| def get_ocr_model(): | |
| return ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True) | |
| def get_llm_model(device): | |
| return pipeline("text-generation", model="meta-llama/Meta-Llama-3.1-8B-Instruct", device=device) | |
| from predict_output import predict_output | |
| layout_predictor = get_layout_predictor() | |
| model = get_ocr_model() | |
| pipe = get_llm_model("cuda") | |
| print("Models loaded") | |
| # --- Placeholder function for demo --- | |
| def get_corresponding_bboxes(image, question): | |
| # Returns dummy bounding boxes and answer for demo | |
| # Each bbox: (x1, y1, x2, y2) | |
| w, h = image.size | |
| block_bboxes = [(w//8, h//8, w//2, h//2)] | |
| line_bboxes = [(w//4, h//4, w//2, h//3)] | |
| word_bboxes = [(w//3, h//3, w//2, h//2)] | |
| point_bboxes = [(w//2, h//2, w//2+5, h//2+5)] | |
| answer = "This is a demo answer." | |
| return block_bboxes, line_bboxes, word_bboxes, point_bboxes, answer | |
| # --- Helper to draw bboxes --- | |
| def draw_bboxes(image, bboxes, color): | |
| img = image.copy() | |
| # width proportional to the image size | |
| width = int(img.width/100) | |
| draw = ImageDraw.Draw(img) | |
| for bbox in bboxes: | |
| draw.rectangle(bbox, outline=color, width=width) | |
| return img | |
| def draw_points(image, bboxes, color): | |
| img = image.copy() | |
| width = int(img.width) | |
| draw = ImageDraw.Draw(img) | |
| for bbox in bboxes: | |
| # x1, y1, x2, y2 = bbox | |
| cx, cy = bbox[0], bbox[1] | |
| # r being relative to the image size | |
| r = int(img.width/100) | |
| draw.ellipse((cx-r, cy-r, cx+r, cy+r), outline=color, width=width, fill=color) | |
| return img | |
| # model_type = st.sidebar.checkbox("Use LLM Model", value=False) | |
| # model_type = "llm" if model_type else "inhouse" | |
| st.markdown(""" | |
| <style> | |
| .main { | |
| background: linear-gradient(135deg, #f8fafc 0%, #e0e7ef 100%); | |
| } | |
| .block-container { | |
| padding-top: 2rem; | |
| padding-bottom: 2rem; | |
| } | |
| .stButton>button { | |
| background-color: #4F8BF9; | |
| color: white; | |
| border-radius: 8px; | |
| font-size: 1.1rem; | |
| padding: 0.5em 2em; | |
| } | |
| .stTextInput>div>input { | |
| border-radius: 8px; | |
| border: 1px solid #4F8BF9; | |
| } | |
| .stFileUploader>div>div { | |
| border-radius: 8px; | |
| border: 2px dashed #4F8BF9; | |
| } | |
| .stAudio>audio { | |
| width: 100% !important; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| col_logo, col_title = st.columns([1, 8]) | |
| with col_logo: | |
| st.image("logo.png", width=180) | |
| with col_title: | |
| st.markdown("<h1 style='margin-bottom: 0;'>MGVG - Multi-Granular Visual Grounding</h1>", unsafe_allow_html=True) | |
| # List of quotes (HTML formatted) | |
| QUOTES = [ | |
| '''<div style="color: #2b6cb0; font-size: 1.3em; font-weight: 500; margin-bottom: 1em;"> | |
| "प्रत्यक्षं किं प्रमाणं?" <span style="font-size:0.9em; color:#444;">(<i>What better proof is there than direct perception?)</i></span> | |
| </div>''', | |
| '''<div style="color: #2b6cb0; font-size: 1.3em; font-weight: 500; margin-bottom: 1em;"> | |
| <i>"Truth is not told—it is seen."</i> | |
| </div>''' | |
| ] | |
| # Initialize session state for quote index and last update time | |
| if "quote_index" not in st.session_state: | |
| st.session_state.quote_index = 0 | |
| st.session_state.last_quote_time = time.time() | |
| # Check if 5 seconds have passed | |
| if time.time() - st.session_state.last_quote_time > 5: | |
| st.session_state.quote_index = (st.session_state.quote_index + 1) % len(QUOTES) | |
| st.session_state.last_quote_time = time.time() | |
| # Rerun the app to update the quote | |
| if hasattr(st, "experimental_rerun"): | |
| st.experimental_rerun() | |
| # Display the current quote | |
| st.markdown(QUOTES[st.session_state.quote_index], unsafe_allow_html=True) | |
| col1, col2 = st.columns([1, 2]) | |
| with col1: | |
| st.subheader("1. Upload Image or pdf document") | |
| image = "Not Uploaded" | |
| uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg", "pdf"]) | |
| if uploaded_file: | |
| current_dir = os.getcwd() | |
| temp_output_folder = os.path.join(current_dir, "temp_output_folder/") | |
| # delete the temp_output_folder | |
| if os.path.exists(temp_output_folder): | |
| shutil.rmtree(temp_output_folder) | |
| document_type = "image" | |
| if uploaded_file.type == "application/pdf": | |
| # save the uploaded file to a temp file | |
| temp_file_path = os.path.join(current_dir, "temp_file.pdf") | |
| # delete the temp_file_path | |
| if os.path.exists(temp_file_path): | |
| os.remove(temp_file_path) | |
| with open(temp_file_path, "wb") as f: | |
| f.write(uploaded_file.getbuffer()) | |
| if not os.path.exists(temp_output_folder): | |
| os.makedirs(temp_output_folder) | |
| # output_file = simple_counter_generator("page", ".jpg") | |
| # convert_from_path(document_path, output_folder=temp_output_folder, dpi=300, fmt='jpeg', jpegopt= jpg_options, output_file=output_file) | |
| pages = 0 | |
| doc = pymupdf.open(temp_file_path) # open document | |
| for page in doc: # iterate through the pages | |
| pages += 1 | |
| pix = page.get_pixmap() # render page to an image | |
| pix.save(f"{temp_output_folder}/{page.number}.png") | |
| if(pages == 1): | |
| document_type = "image" | |
| document_path = os.path.join(temp_output_folder, "0.png") | |
| uploaded_file = os.path.join(temp_output_folder, "0.png") | |
| image = Image.open(uploaded_file).convert("RGB") | |
| else: | |
| document_type = "pdf" | |
| # image = Image.open(uploaded_file).convert("RGB") | |
| if document_type == "image": | |
| image = Image.open(uploaded_file).convert("RGB") | |
| st.image(image, caption="Uploaded Image", use_container_width=True) | |
| # Save uploaded image to a temp file for predict_output | |
| temp_file_path = "sample.png" | |
| image.save(temp_file_path) | |
| else: | |
| document_type = "pdf" | |
| document_path = uploaded_file.name | |
| image = "Uploaded PDF" | |
| # st.image(uploaded_file, caption="Uploaded PDF", use_container_width=True) | |
| else: | |
| image = "Not Uploaded" | |
| temp_output_folder = None | |
| st.image("https://placehold.co/400x300?text=Upload+Image", caption="Uploaded Image", use_container_width=True) | |
| st.subheader("2. Ask a question") | |
| question = st.text_input("Type your question here") | |
| # Add radio button for model selection | |
| model_type = st.radio( | |
| "Select Model Type:", | |
| options=["MGVG", "IndoDocs"], | |
| index=1, | |
| horizontal=True | |
| ) | |
| run_demo = st.button("Run Grounding Demo", use_container_width=True) | |
| # --- Output placeholders --- | |
| with col2: | |
| st.subheader("3. Visual Grounding Outputs") | |
| if image!="Not Uploaded" and (question): | |
| print(image) | |
| print(question) | |
| if run_demo and image!="Not Uploaded" and (question): | |
| # Use text input only | |
| q = question | |
| answer, block_bboxes, line_bboxes, word_bboxes, point_bboxes, current_page = predict_output( | |
| temp_file_path, q, pipe, layout_predictor, model, model_type, document_type | |
| ) | |
| # print(block_bboxes) | |
| # print(line_bboxes) | |
| # print(word_bboxes) | |
| # print(point_bboxes) | |
| print(answer) | |
| if(current_page != -1): | |
| image = Image.open(os.path.join(temp_output_folder, f"{current_page}.png")).convert("RGB") | |
| print("--------------------------------") | |
| print(image) | |
| block_img = draw_bboxes(image, block_bboxes, color="#4F8BF9") | |
| line_img = draw_bboxes(image, line_bboxes, color="#F97B4F") | |
| word_img = draw_bboxes(image, word_bboxes, color="#4FF9B2") | |
| point_img = draw_points(image, point_bboxes, color="#FFFF00") | |
| imgs = [block_img, line_img, word_img, point_img] | |
| labels = ["Block Level", "Line Level", "Word Level", "Point Level"] | |
| cols = st.columns(4) | |
| for i, (img, label) in enumerate(zip(imgs, labels)): | |
| with cols[i]: | |
| st.image(img, caption=label, use_container_width=True) | |
| answer_lines = answer.splitlines() | |
| st.markdown(""" | |
| <div style='background: #f1f5fa; border-radius: 10px; padding: 1em 2em; border: 1.5px solid #4F8BF9;'> | |
| <h4 style='color: #4F8BF9;'>Predicted Answer:</h4> | |
| <p style='font-size: 1.2em; color: #222;'>""" + "<br>".join(answer_lines) + """</p> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # --- Centered Save Results Button --- | |
| result_data = { | |
| "question": q, | |
| "answer": answer, | |
| "block_bboxes": block_bboxes, | |
| "line_bboxes": line_bboxes, | |
| "word_bboxes": word_bboxes, | |
| "point_bboxes": point_bboxes, | |
| "current_page": current_page | |
| } | |
| json_str = json.dumps(result_data, indent=2) | |
| col_left, col_center, col_right = st.columns([2, 3, 2]) | |
| with col_center: | |
| st.download_button( | |
| label="Save Results as JSON", | |
| data=json_str, | |
| file_name="grounding_results.json", | |
| mime="application/json" | |
| ) | |
| else: | |
| st.markdown(""" | |
| <div style='display: flex; gap: 2em; flex-wrap: wrap;'> | |
| <div style='flex: 1; min-width: 220px;'> | |
| <img src='https://placehold.co/220x180?text=Block+Level' style='width:100%; border-radius: 10px; border: 2px solid #4F8BF9;'> | |
| <p style='text-align:center; font-weight:600;'>Block Level</p> | |
| </div> | |
| <div style='flex: 1; min-width: 220px;'> | |
| <img src='https://placehold.co/220x180?text=Line+Level' style='width:100%; border-radius: 10px; border: 2px solid #4F8BF9;'> | |
| <p style='text-align:center; font-weight:600;'>Line Level</p> | |
| </div> | |
| <div style='flex: 1; min-width: 220px;'> | |
| <img src='https://placehold.co/220x180?text=Word+Level' style='width:100%; border-radius: 10px; border: 2px solid #4F8BF9;'> | |
| <p style='text-align:center; font-weight:600;'>Word Level</p> | |
| </div> | |
| <div style='flex: 1; min-width: 220px;'> | |
| <img src='https://placehold.co/220x180?text=Point+Level' style='width:100%; border-radius: 10px; border: 2px solid #4F8BF9;'> | |
| <p style='text-align:center; font-weight:600;'>Point Level</p> | |
| </div> | |
| </div> | |
| <br> | |
| <div style='background: #f1f5fa; border-radius: 10px; padding: 1em 2em; border: 1.5px solid #4F8BF9;'> | |
| <h4 style='color: #4F8BF9;'>Predicted Answer:</h4> | |
| <p style='font-size: 1.2em; color: #222;'>[Answer will appear here]</p> | |
| </div> | |
| """, unsafe_allow_html=True) | |