Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -24,19 +24,29 @@ import numpy as np
|
|
| 24 |
import gspread
|
| 25 |
from dotenv import load_dotenv
|
| 26 |
|
|
|
|
| 27 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
|
|
|
|
|
| 28 |
scopes = ["https://www.googleapis.com/auth/spreadsheets"]
|
| 29 |
creds = Credentials.from_service_account_file("credentials.json", scopes=scopes)
|
| 30 |
client = gspread.authorize(creds)
|
|
|
|
|
|
|
| 31 |
#environment
|
| 32 |
load_dotenv()
|
| 33 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
| 34 |
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
|
| 35 |
-
llm = ChatGroq(model="llama-3.1-70b-versatile")
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
# Initialize Google Serper API wrapper
|
| 39 |
search = GoogleSerperAPIWrapper(serp_api_key=SERPER_API_KEY)
|
|
|
|
| 40 |
|
| 41 |
# Create the system and human messages for dynamic query processing
|
| 42 |
system_message_content = """
|
|
@@ -74,6 +84,7 @@ def perform_web_search(query, max_retries=3, delay=2):
|
|
| 74 |
time.sleep(delay)
|
| 75 |
st.error(f"Failed to perform web search for query '{query}' after {max_retries} retries.")
|
| 76 |
return "NaN"
|
|
|
|
| 77 |
def update_google_sheet(sheet_id, range_name, data):
|
| 78 |
try:
|
| 79 |
# Define the Google Sheets API scope
|
|
@@ -117,6 +128,7 @@ def get_llm_response(entity, query, web_results):
|
|
| 117 |
return cleaned_info
|
| 118 |
except Exception as e:
|
| 119 |
return "NaN"
|
|
|
|
| 120 |
# Retry logic for multiple web searches if necessary
|
| 121 |
def refine_answer_with_searches(entity, query, max_retries=3):
|
| 122 |
search_results = perform_web_search(query.format(entity=entity))
|
|
@@ -329,54 +341,203 @@ elif selected == "Extract Information":
|
|
| 329 |
|
| 330 |
column_selection = st.session_state["column_selection"]
|
| 331 |
entities_column = st.session_state["data"][column_selection]
|
| 332 |
-
|
| 333 |
-
st.
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
st.
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
st.
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
else:
|
| 368 |
st.warning("Please upload your data and define the query template.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
elif selected == "View & Download":
|
| 371 |
st.header("View & Download Results")
|
| 372 |
|
| 373 |
-
if "results" in st.session_state:
|
| 374 |
results_df = pd.DataFrame(st.session_state["results"])
|
| 375 |
st.write("### Results Preview")
|
| 376 |
|
| 377 |
-
# Display
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
| 379 |
|
|
|
|
| 380 |
download_option = st.selectbox(
|
| 381 |
"Select data to download:",
|
| 382 |
["All Results", "Extracted Information", "Web Results"]
|
|
@@ -396,36 +557,43 @@ elif selected == "View & Download":
|
|
| 396 |
mime="text/csv"
|
| 397 |
)
|
| 398 |
|
| 399 |
-
#
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
-
|
| 406 |
-
|
|
|
|
| 407 |
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
|
| 412 |
-
|
| 413 |
-
|
| 414 |
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
else:
|
| 431 |
st.warning("No results available to view. Please run the extraction process.")
|
|
|
|
| 24 |
import gspread
|
| 25 |
from dotenv import load_dotenv
|
| 26 |
|
| 27 |
+
|
| 28 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 29 |
+
|
| 30 |
+
#google sheet
|
| 31 |
scopes = ["https://www.googleapis.com/auth/spreadsheets"]
|
| 32 |
creds = Credentials.from_service_account_file("credentials.json", scopes=scopes)
|
| 33 |
client = gspread.authorize(creds)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
#environment
|
| 37 |
load_dotenv()
|
| 38 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
| 39 |
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
+
#session state variables
|
| 43 |
+
if "results" not in st.session_state:
|
| 44 |
+
st.session_state["results"] = []
|
| 45 |
+
|
| 46 |
+
|
| 47 |
# Initialize Google Serper API wrapper
|
| 48 |
search = GoogleSerperAPIWrapper(serp_api_key=SERPER_API_KEY)
|
| 49 |
+
llm = ChatGroq(model="llama-3.1-70b-versatile")
|
| 50 |
|
| 51 |
# Create the system and human messages for dynamic query processing
|
| 52 |
system_message_content = """
|
|
|
|
| 84 |
time.sleep(delay)
|
| 85 |
st.error(f"Failed to perform web search for query '{query}' after {max_retries} retries.")
|
| 86 |
return "NaN"
|
| 87 |
+
|
| 88 |
def update_google_sheet(sheet_id, range_name, data):
|
| 89 |
try:
|
| 90 |
# Define the Google Sheets API scope
|
|
|
|
| 128 |
return cleaned_info
|
| 129 |
except Exception as e:
|
| 130 |
return "NaN"
|
| 131 |
+
|
| 132 |
# Retry logic for multiple web searches if necessary
|
| 133 |
def refine_answer_with_searches(entity, query, max_retries=3):
|
| 134 |
search_results = perform_web_search(query.format(entity=entity))
|
|
|
|
| 341 |
|
| 342 |
column_selection = st.session_state["column_selection"]
|
| 343 |
entities_column = st.session_state["data"][column_selection]
|
| 344 |
+
|
| 345 |
+
col1, col2 = st.columns([2, 1])
|
| 346 |
+
with col1:
|
| 347 |
+
st.write("### Selected Entity Column:")
|
| 348 |
+
st.dataframe(entities_column, use_container_width=True)
|
| 349 |
+
|
| 350 |
+
with col2:
|
| 351 |
+
start_button = st.button("Start Extraction", type="primary", use_container_width=True)
|
| 352 |
+
|
| 353 |
+
results_container = st.empty()
|
| 354 |
+
|
| 355 |
+
if start_button:
|
| 356 |
+
with st.spinner("Extracting information..."):
|
| 357 |
+
progress_bar = st.progress(0)
|
| 358 |
+
progress_text = st.empty()
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
results = []
|
| 362 |
+
for i, selected_entity in enumerate(entities_column):
|
| 363 |
+
user_query = st.session_state["query_template"].replace("{entity}", str(selected_entity))
|
| 364 |
+
final_answer, search_results = refine_answer_with_searches(selected_entity, user_query)
|
| 365 |
+
results.append({
|
| 366 |
+
"Entity": selected_entity,
|
| 367 |
+
"Extracted Information": final_answer,
|
| 368 |
+
"Search Results": search_results
|
| 369 |
+
})
|
| 370 |
+
|
| 371 |
+
progress = (i + 1) / len(entities_column)
|
| 372 |
+
progress_bar.progress(progress)
|
| 373 |
+
progress_text.text(f"Processing {i+1}/{len(entities_column)} entities...")
|
| 374 |
+
|
| 375 |
+
st.session_state["results"] = results
|
| 376 |
+
|
| 377 |
+
progress_bar.empty()
|
| 378 |
+
progress_text.empty()
|
| 379 |
+
st.success("Extraction completed successfully!")
|
| 380 |
+
|
| 381 |
+
except Exception as e:
|
| 382 |
+
st.error(f"An error occurred during extraction: {str(e)}")
|
| 383 |
+
st.session_state.pop("results", None)
|
| 384 |
+
|
| 385 |
+
if "results" in st.session_state and st.session_state["results"]:
|
| 386 |
+
with results_container:
|
| 387 |
+
results = st.session_state["results"]
|
| 388 |
+
|
| 389 |
+
search_query = st.text_input("🔍 Search results", "")
|
| 390 |
+
|
| 391 |
+
tab1, tab2 = st.tabs(["Compact View", "Detailed View"])
|
| 392 |
+
|
| 393 |
+
with tab1:
|
| 394 |
+
found_results = False
|
| 395 |
+
for result in results:
|
| 396 |
+
if search_query.lower() in str(result["Entity"]).lower() or \
|
| 397 |
+
search_query.lower() in str(result["Extracted Information"]).lower():
|
| 398 |
+
found_results = True
|
| 399 |
+
with st.expander(f"📋 {result['Entity']}", expanded=False):
|
| 400 |
+
st.markdown("#### Extracted Information")
|
| 401 |
+
st.write(result["Extracted Information"])
|
| 402 |
+
|
| 403 |
+
if not found_results and search_query:
|
| 404 |
+
st.info("No results found for your search.")
|
| 405 |
+
|
| 406 |
+
with tab2:
|
| 407 |
+
found_results = False
|
| 408 |
+
for i, result in enumerate(results):
|
| 409 |
+
if search_query.lower() in str(result["Entity"]).lower() or \
|
| 410 |
+
search_query.lower() in str(result["Extracted Information"]).lower():
|
| 411 |
+
found_results = True
|
| 412 |
+
st.markdown(f"### Entity {i+1}: {result['Entity']}")
|
| 413 |
+
|
| 414 |
+
col1, col2 = st.columns(2)
|
| 415 |
+
|
| 416 |
+
with col1:
|
| 417 |
+
st.markdown("#### 📝 Extracted Information")
|
| 418 |
+
st.info(result["Extracted Information"])
|
| 419 |
+
|
| 420 |
+
with col2:
|
| 421 |
+
st.markdown("#### 🔍 Search Results")
|
| 422 |
+
st.warning(result["Search Results"])
|
| 423 |
+
|
| 424 |
+
st.divider()
|
| 425 |
+
|
| 426 |
+
if not found_results and search_query:
|
| 427 |
+
st.info("No results found for your search.")
|
| 428 |
else:
|
| 429 |
st.warning("Please upload your data and define the query template.")
|
| 430 |
+
|
| 431 |
+
elif selected == "Extract Information":
|
| 432 |
+
st.header("Extract Information")
|
| 433 |
+
|
| 434 |
+
if "query_template" in st.session_state and "data" in st.session_state:
|
| 435 |
+
st.write("### Using Query Template:")
|
| 436 |
+
st.code(st.session_state["query_template"])
|
| 437 |
|
| 438 |
+
column_selection = st.session_state["column_selection"]
|
| 439 |
+
entities_column = st.session_state["data"][column_selection]
|
| 440 |
+
|
| 441 |
+
col1, col2 = st.columns([2, 1])
|
| 442 |
+
with col1:
|
| 443 |
+
st.write("### Selected Entity Column:")
|
| 444 |
+
st.dataframe(entities_column, use_container_width=True)
|
| 445 |
+
|
| 446 |
+
with col2:
|
| 447 |
+
start_button = st.button("Start Extraction", type="primary", use_container_width=True)
|
| 448 |
+
|
| 449 |
+
results_container = st.empty()
|
| 450 |
+
|
| 451 |
+
if start_button:
|
| 452 |
+
with st.spinner("Extracting information..."):
|
| 453 |
+
progress_bar = st.progress(0)
|
| 454 |
+
progress_text = st.empty()
|
| 455 |
+
|
| 456 |
+
try:
|
| 457 |
+
results = []
|
| 458 |
+
for i, selected_entity in enumerate(entities_column):
|
| 459 |
+
user_query = st.session_state["query_template"].replace("{entity}", str(selected_entity))
|
| 460 |
+
final_answer, search_results = refine_answer_with_searches(selected_entity, user_query)
|
| 461 |
+
results.append({
|
| 462 |
+
"Entity": selected_entity,
|
| 463 |
+
"Extracted Information": final_answer,
|
| 464 |
+
"Search Results": search_results
|
| 465 |
+
})
|
| 466 |
+
|
| 467 |
+
progress = (i + 1) / len(entities_column)
|
| 468 |
+
progress_bar.progress(progress)
|
| 469 |
+
progress_text.text(f"Processing {i+1}/{len(entities_column)} entities...")
|
| 470 |
+
|
| 471 |
+
st.session_state["results"] = results
|
| 472 |
+
|
| 473 |
+
progress_bar.empty()
|
| 474 |
+
progress_text.empty()
|
| 475 |
+
st.success("Extraction completed successfully!")
|
| 476 |
+
|
| 477 |
+
except Exception as e:
|
| 478 |
+
st.error(f"An error occurred during extraction: {str(e)}")
|
| 479 |
+
st.session_state.pop("results", None)
|
| 480 |
+
|
| 481 |
+
if "results" in st.session_state and st.session_state["results"]:
|
| 482 |
+
with results_container:
|
| 483 |
+
results = st.session_state["results"]
|
| 484 |
+
|
| 485 |
+
search_query = st.text_input("🔍 Search results", "")
|
| 486 |
+
|
| 487 |
+
tab1, tab2 = st.tabs(["Compact View", "Detailed View"])
|
| 488 |
+
|
| 489 |
+
with tab1:
|
| 490 |
+
found_results = False
|
| 491 |
+
for result in results:
|
| 492 |
+
if search_query.lower() in str(result["Entity"]).lower() or \
|
| 493 |
+
search_query.lower() in str(result["Extracted Information"]).lower():
|
| 494 |
+
found_results = True
|
| 495 |
+
with st.expander(f"📋 {result['Entity']}", expanded=False):
|
| 496 |
+
st.markdown("#### Extracted Information")
|
| 497 |
+
st.write(result["Extracted Information"])
|
| 498 |
+
|
| 499 |
+
if not found_results and search_query:
|
| 500 |
+
st.info("No results found for your search.")
|
| 501 |
+
|
| 502 |
+
with tab2:
|
| 503 |
+
found_results = False
|
| 504 |
+
for i, result in enumerate(results):
|
| 505 |
+
if search_query.lower() in str(result["Entity"]).lower() or \
|
| 506 |
+
search_query.lower() in str(result["Extracted Information"]).lower():
|
| 507 |
+
found_results = True
|
| 508 |
+
st.markdown(f"### Entity {i+1}: {result['Entity']}")
|
| 509 |
+
|
| 510 |
+
col1, col2 = st.columns(2)
|
| 511 |
+
|
| 512 |
+
with col1:
|
| 513 |
+
st.markdown("#### 📝 Extracted Information")
|
| 514 |
+
st.info(result["Extracted Information"])
|
| 515 |
+
|
| 516 |
+
with col2:
|
| 517 |
+
st.markdown("#### 🔍 Search Results")
|
| 518 |
+
st.warning(result["Search Results"])
|
| 519 |
+
|
| 520 |
+
st.divider()
|
| 521 |
+
|
| 522 |
+
if not found_results and search_query:
|
| 523 |
+
st.info("No results found for your search.")
|
| 524 |
+
else:
|
| 525 |
+
st.warning("Please upload your data and define the query template.")
|
| 526 |
+
|
| 527 |
elif selected == "View & Download":
|
| 528 |
st.header("View & Download Results")
|
| 529 |
|
| 530 |
+
if "results" in st.session_state and st.session_state["results"]:
|
| 531 |
results_df = pd.DataFrame(st.session_state["results"])
|
| 532 |
st.write("### Results Preview")
|
| 533 |
|
| 534 |
+
# Display the results preview
|
| 535 |
+
if "Extracted Information" in results_df.columns and "Search Results" in results_df.columns:
|
| 536 |
+
st.dataframe(results_df.style.map(lambda val: 'background-color: #d3f4ff' if isinstance(val, str) else '', subset=["Extracted Information", "Search Results"]))
|
| 537 |
+
else:
|
| 538 |
+
st.warning("Required columns are missing in results data.")
|
| 539 |
|
| 540 |
+
# Download options
|
| 541 |
download_option = st.selectbox(
|
| 542 |
"Select data to download:",
|
| 543 |
["All Results", "Extracted Information", "Web Results"]
|
|
|
|
| 557 |
mime="text/csv"
|
| 558 |
)
|
| 559 |
|
| 560 |
+
# Option to update Google Sheets
|
| 561 |
+
update_option = st.selectbox(
|
| 562 |
+
"Do you want to update Google Sheets?",
|
| 563 |
+
["No", "Yes"]
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
if update_option == "Yes":
|
| 567 |
+
if 'sheet_id' not in st.session_state:
|
| 568 |
+
st.session_state.sheet_id = ''
|
| 569 |
+
if 'range_name' not in st.session_state:
|
| 570 |
+
st.session_state.range_name = ''
|
| 571 |
|
| 572 |
+
# Input fields for Google Sheets ID and Range
|
| 573 |
+
sheet_id = st.text_input("Enter Google Sheet ID", value=st.session_state.sheet_id)
|
| 574 |
+
range_name = st.text_input("Enter Range (e.g., 'Sheet1!A1')", value=st.session_state.range_name)
|
| 575 |
|
| 576 |
+
if sheet_id and range_name:
|
| 577 |
+
st.session_state.sheet_id = sheet_id
|
| 578 |
+
st.session_state.range_name = range_name
|
| 579 |
|
| 580 |
+
# Prepare data for update
|
| 581 |
+
data_to_update = [results_df.columns.tolist()] + results_df.values.tolist()
|
| 582 |
|
| 583 |
+
# Update Google Sheets button
|
| 584 |
+
if st.button("Update Google Sheet"):
|
| 585 |
+
try:
|
| 586 |
+
if '!' not in range_name:
|
| 587 |
+
st.error("Invalid range format. Please use the format 'SheetName!Range'.")
|
| 588 |
+
else:
|
| 589 |
+
sheet_name, cell_range = range_name.split('!', 1)
|
| 590 |
+
sheet = client.open_by_key(sheet_id).worksheet(sheet_name)
|
| 591 |
+
sheet.clear()
|
| 592 |
+
sheet.update(f"{cell_range}", data_to_update)
|
| 593 |
+
st.success("Data updated in the Google Sheet!")
|
| 594 |
+
except Exception as e:
|
| 595 |
+
st.error(f"Error updating Google Sheet: {e}")
|
| 596 |
+
else:
|
| 597 |
+
st.warning("Please enter both the Sheet ID and Range name before updating.")
|
| 598 |
else:
|
| 599 |
st.warning("No results available to view. Please run the extraction process.")
|