| | import datetime |
| | |
| | import gspread |
| | import random |
| | import time |
| | import functools |
| | from gspread.exceptions import SpreadsheetNotFound, APIError |
| | from oauth2client.service_account import ServiceAccountCredentials |
| | import pandas as pd |
| | import json |
| | import gradio as gr |
| | import os |
| |
|
# Service-account descriptor in the layout expected by
# ServiceAccountCredentials.from_json_keyfile_dict. The private key itself is
# deliberately left as None here; it is filled in from the environment.
GSERVICE_ACCOUNT_INFO = dict(
    type="service_account",
    project_id="txagent",
    private_key_id="cc1a12e427917244a93faf6f19e72b589a685e65",
    private_key=None,
    client_email="shanghua@txagent.iam.gserviceaccount.com",
    client_id="108950722202634464257",
    auth_uri="https://accounts.google.com/o/oauth2/auth",
    token_uri="https://oauth2.googleapis.com/token",
    auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
    client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/shanghua%40txagent.iam.gserviceaccount.com",
    universe_domain="googleapis.com",
)

# Default spreadsheet used when callers do not pass a custom sheet name.
GSHEET_NAME = "TxAgent_data_collection"
| |
|
# The service-account private key is read from the environment instead of
# being stored in source. Newlines in it are expected as literal backslash-n
# sequences and are turned back into real newlines before use.
GSheet_API_KEY = os.environ.get("GSheets_Shanghua_PrivateKey")
if GSheet_API_KEY is not None:
    GSheet_API_KEY = GSheet_API_KEY.replace("\\n", "\n")
    GSERVICE_ACCOUNT_INFO["private_key"] = GSheet_API_KEY
else:
    # NOTE(review): private_key stays None, so the credentials built from
    # GSERVICE_ACCOUNT_INFO later in this module will have no key — confirm
    # that failing at that point is the intended behavior.
    print("GSheet_API_KEY not found in environment variables. Please set it.")
| |
|
| | |
def exponential_backoff_gspread(max_retries=30, max_backoff_sec=64, base_delay_sec=1, target_exception=APIError):
    """
    Decorator that retries a gspread call with exponential backoff on rate-limit errors.

    Only exceptions of type ``target_exception`` that represent a Google Sheets
    API rate-limit error (HTTP 429) are retried. Any other ``target_exception``,
    and any other exception, is logged and re-raised immediately.

    Args:
        max_retries (int): Maximum number of retry attempts before giving up.
        max_backoff_sec (int | float): Upper bound on the delay between retries, in seconds.
        base_delay_sec (int | float): Delay before the first retry; doubled on each further retry.
        target_exception (type[Exception]): The exception type inspected for rate limiting.

    Returns:
        Callable: a decorator that wraps the target function with retry logic.

    Raises:
        The caught exception, unchanged, once ``max_retries`` is exceeded or as
        soon as a non-rate-limit error occurs.
    """

    def _is_rate_limit_error(exc):
        """Return True if *exc* looks like an HTTP 429 (rate limit) error."""
        # Prefer the HTTP status of the underlying response when the exception
        # carries one (gspread's APIError keeps it as ``.response``). The message
        # text is only a fallback because its format is not guaranteed to be stable.
        response = getattr(exc, "response", None)
        if getattr(response, "status_code", None) == 429:
            return True
        error_message = str(exc)
        return "[429]" in error_message and (
            "Quota exceeded" in error_message or "Too Many Requests" in error_message
        )

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            retries = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except target_exception as e:
                    if not _is_rate_limit_error(e):
                        print(f"Non-rate-limit APIError encountered in {func.__name__}: {e}")
                        raise

                    retries += 1
                    if retries > max_retries:
                        print(f"Max retries ({max_retries}) exceeded for {func.__name__}. Last error: {e}")
                        raise

                    # Exponential growth plus up to 1 s of random jitter, capped at max_backoff_sec.
                    backoff_delay = min(max_backoff_sec, base_delay_sec * (2 ** (retries - 1)) + random.uniform(0, 1))
                    print(
                        f"Rate limit hit for {func.__name__} (Attempt {retries}/{max_retries}). "
                        f"Retrying in {backoff_delay:.2f} seconds. Error: {e}"
                    )
                    time.sleep(backoff_delay)
                except Exception as e:
                    print(f"An unexpected error occurred in {func.__name__}: {e}")
                    raise

        return wrapper

    return decorator
| |
|
| | |
| | |
# OAuth scopes requested for the service-account credentials below.
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]

# One authorized gspread client is created at import time and shared by the
# sheet helpers in this module (read_sheet_to_df, append_to_sheet).
# NOTE(review): if the private key was not found in the environment,
# GSERVICE_ACCOUNT_INFO["private_key"] is None at this point — confirm how
# credential construction is expected to behave then.
creds = ServiceAccountCredentials.from_json_keyfile_dict(GSERVICE_ACCOUNT_INFO, scope)
client = gspread.authorize(creds)
| |
|
@exponential_backoff_gspread(max_retries=30, max_backoff_sec=64)
def read_sheet_to_df(custom_sheet_name=None, sheet_index=0):
    """
    Read all records of one worksheet of a Google Sheet into a pandas DataFrame.

    Parameters:
        custom_sheet_name (str | None): Name of the spreadsheet to open. If None, uses GSHEET_NAME.
        sheet_index (int): Zero-based index of the worksheet within the spreadsheet
            (default 0, the first sheet).

    Returns:
        pandas.DataFrame | None: One row per record, with the first sheet row used
        as column headers; None if the spreadsheet or the worksheet does not exist.
    """
    if custom_sheet_name is None:
        custom_sheet_name = GSHEET_NAME

    try:
        spreadsheet = client.open(custom_sheet_name)
    except SpreadsheetNotFound:
        return None

    # gspread reports a missing worksheet index either by returning None or by
    # raising WorksheetNotFound, depending on the library version. The original
    # IndexError catch alone never matched either case, so handle all of them.
    try:
        worksheet = spreadsheet.get_worksheet(sheet_index)
    except (IndexError, gspread.exceptions.WorksheetNotFound):
        return None
    if worksheet is None:
        return None

    return pd.DataFrame(worksheet.get_all_records())
| |
|
@exponential_backoff_gspread(max_retries=30, max_backoff_sec=64)
def append_to_sheet(user_data=None, custom_row_dict=None, custom_sheet_name=None, add_header_when_create_sheet=False):
    """
    Append one row to the first worksheet of a Google Sheet, creating the spreadsheet if needed.

    If ``custom_row_dict`` is given, its values are laid out in the order of the
    sheet's header row. Header columns missing from the dict become "", and dict
    keys that are not in the header are dropped. Otherwise a default row
    ``[timestamp, question, final_answer, trace]`` is built from ``user_data``.

    Args:
        user_data (dict | None): Must contain "question", "final_answer" and
            "trace" when ``custom_row_dict`` is None.
        custom_row_dict (dict | None): Mapping of column header -> cell value.
        custom_sheet_name (str | None): Spreadsheet name; defaults to GSHEET_NAME.
        add_header_when_create_sheet (bool): When the spreadsheet was just
            created, or its first worksheet is empty, first append a header row
            built from the keys of ``custom_row_dict`` (or of ``user_data``).

    Note:
        A newly created spreadsheet is shared (writer role) with two hard-coded
        accounts.
    """
    if custom_sheet_name is None:
        custom_sheet_name = GSHEET_NAME

    try:
        spreadsheet = client.open(custom_sheet_name)
        is_new = False
    except SpreadsheetNotFound:
        # Create the spreadsheet and grant write access to the hard-coded accounts.
        spreadsheet = client.create(custom_sheet_name)
        spreadsheet.share('shanghuagao@gmail.com', perm_type='user', role='writer')
        spreadsheet.share('rzhu@college.harvard.edu', perm_type='user', role='writer')
        is_new = True

    print("Spreadsheet ID:", spreadsheet.id)
    sheet = spreadsheet.sheet1

    # An empty worksheet can be reported as [] or [[]] depending on the gspread
    # version. `any` treats both, and any list of empty rows, as empty. The old
    # `== [[]]` check missed the [] case, so headers were never written.
    existing_values = sheet.get_all_values()
    is_empty = not any(existing_values)

    if (is_new or is_empty) and add_header_when_create_sheet:
        header_source = custom_row_dict if custom_row_dict is not None else user_data
        headers = list(header_source.keys())
        sheet.append_row(headers)
    else:
        headers = sheet.row_values(1) if sheet.row_count > 0 else []

    if custom_row_dict is not None:
        # Align values with the existing column order.
        custom_row = [custom_row_dict.get(header, "") for header in headers]
    else:
        # Default layout: the headers are not consulted here.
        custom_row = [str(datetime.datetime.now()), user_data["question"], user_data["final_answer"], user_data["trace"]]

    sheet.append_row(custom_row)
| |
|
def format_chat(response, tool_database_labels):
    """
    Convert an agent conversation into Gradio chat messages.

    Args:
        response (list[dict]): Conversation messages, each with a "role".
            Assistant messages may carry "content" and "tool_calls". "tool_calls"
            is a JSON string or list of dicts with "name" and "arguments". Tool
            messages carry the tool output in "content".
        tool_database_labels (dict): Maps a database label to a collection of
            tool names. It is used to annotate tool-response titles.

    Returns:
        tuple: ``(final_answer_messages, reasoning_messages, chat_history)``.
        - ``final_answer_messages``: one message holding the text after the last
          "**Answer:**" marker of the final message, or an empty list when
          ``response`` yields no messages.
        - ``reasoning_messages``: a "No reasoning was conducted." placeholder when
          exactly one message was produced, otherwise a copy of the history.
        - ``chat_history``: every formatted message.
    """
    chat_history = []
    # Tool calls announced by the most recent assistant message; they are
    # consumed by the next tool message.
    last_tool_calls = []

    for msg in response:
        if msg["role"] == "assistant":
            # Normalize None content so the string handling below cannot fail on it.
            content = msg.get("content") or ""

            # "tool_calls" may be absent, None, a JSON string, or an already-decoded list.
            raw_calls = msg.get("tool_calls")
            if not raw_calls:
                last_tool_calls = []
            elif isinstance(raw_calls, str):
                last_tool_calls = json.loads(raw_calls) or []
            else:
                last_tool_calls = list(raw_calls)

            chat_history.append(gr.ChatMessage(role="assistant", content=content))

        elif msg["role"] == "tool":
            # Re-serialize JSON output. Non-JSON or non-string output is shown as is.
            raw = msg.get("content", "")
            try:
                pretty = json.dumps(json.loads(raw))
            except (json.JSONDecodeError, TypeError):
                pretty = raw

            # NOTE(review): every pending tool call is paired with this single tool
            # message's content, and the list is then cleared. A second tool message
            # for the same assistant turn therefore produces no entry. Presumably the
            # agent emits one tool message per turn — confirm.
            for tool_call in last_tool_calls:
                name = tool_call.get("name", "")
                args = tool_call.get("arguments", {})

                database_label = ""
                if name == "Tool_RAG":
                    title = "🧰 Tool RAG"
                else:
                    title = f"🛠️ {name}"
                    for db_label, tool_list in tool_database_labels.items():
                        if name in tool_list:
                            title = f"🛠️ {name}\n(**Info** {db_label} [Click to view])"
                            database_label = " (" + db_label + ")"
                            break

                chat_history.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=f"Tool Response{database_label}:\n{pretty}",
                        metadata={
                            "title": title,
                            "log": json.dumps(args),
                            "status": 'done'
                        }
                    )
                )

            last_tool_calls = []

    if not chat_history:
        # Nothing was produced, so there is no final message to extract an answer from.
        return [], [], chat_history

    # Render the agent's "[FinalAnswer]" marker as a bold heading in the last message.
    last_msg = chat_history[-1]
    if isinstance(last_msg.content, str) and "[FinalAnswer]" in last_msg.content:
        last_msg.content = last_msg.content.replace("[FinalAnswer]", "\n**Answer:**\n")

    # The final answer is the text after the last heading, or the whole message if
    # there is no heading. Non-string content is passed through unchanged.
    if isinstance(last_msg.content, str):
        final_text = last_msg.content.split("\n**Answer:**\n")[-1].strip()
    else:
        final_text = last_msg.content
    final_answer_messages = [gr.ChatMessage(role="assistant", content=final_text)]

    assistant_count = sum(1 for m in chat_history if m.role == "assistant")
    if assistant_count == 1:
        # A single message means no tool responses were produced.
        reasoning_messages = [gr.ChatMessage(role="assistant", content="No reasoning was conducted.")]
    else:
        reasoning_messages = chat_history.copy()

    return final_answer_messages, reasoning_messages, chat_history