Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from tools import FreightAgent, EXAMPLE_QUERIES | |
| from utils import initialize_database | |
| from smolagents import CodeAgent, OpenAIServerModel | |
| import os | |
| from dotenv import load_dotenv | |
| from sql_data import sql_query, get_schema, get_csv_as_dataframe | |
| import pandas as pd | |
| from sqlalchemy import create_engine, text | |
# Load environment variables (e.g. OPENAI_API_KEY from a local .env file)
# before any code below reads them.
load_dotenv()

# SQLite database file backing the `freights` table.
DB_PATH = "freights.db"

# Create database engine. create_engine() does not open a connection, so this
# line alone does not create the database file.
engine = create_engine(f"sqlite:///{DB_PATH}")

# Initialize the database if it doesn't exist
if not os.path.exists(DB_PATH):
    csv_url = "https://huggingface.co/datasets/sasu-SpidR/fretmaritime/resolve/main/freights.csv"
    # NOTE(review): assumes initialize_database() writes to DB_PATH; if it
    # writes elsewhere this download re-runs on every start -- confirm in utils.
    initialize_database(csv_url)

# Create the main agent
model_id = "gpt-4.1-mini"
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    # Fail fast with an actionable message instead of a bare KeyError.
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it in the environment "
        "(e.g. as a Space secret) or in a local .env file."
    )
model = OpenAIServerModel(model_id=model_id, api_key=api_key)
# The tools registered here are the objects imported from `sql_data` at the
# top of the file. The same-named functions defined further down this module
# rebind those names only after this line has run, so the agent does not use them.
agent = CodeAgent(tools=[sql_query, get_schema, get_csv_as_dataframe], model=model)
def _format_markdown_table(headers, rows) -> str:
    """Render column headers and row sequences as a Markdown table.

    Every cell is converted with ``str()``. ``|`` is escaped and line breaks
    are replaced by spaces, so a single value can never add a column or split
    a row. Each line, including the last, ends with ``"\\n"``.
    """

    def cell(value) -> str:
        text_value = str(value).replace("|", "\\|")
        return text_value.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")

    lines = [
        "| " + " | ".join(cell(h) for h in headers) + " |",
        "| " + " | ".join("---" for _ in headers) + " |",
    ]
    lines.extend("| " + " | ".join(cell(v) for v in row) + " |" for row in rows)
    return "\n".join(lines) + "\n"


def sql_query(query: str) -> str:
    """
    Allows you to perform SQL queries on the freights table. Returns a string representation of the result.
    The table is named 'freights'. Its description is as follows:
    Columns:
    - departure: DateTime (Date and time of departure)
    - origin_port_locode: String (Origin port code)
    - origin_port_name: String (Name of the origin port)
    - destination_port: String (Destination port code)
    - destination_port_name: String (Name of the destination port)
    - dv20rate: Float (Rate for 20ft container in USD)
    - dv40rate: Float (Rate for 40ft container in USD)
    - currency: String (Currency of the rates)
    - inserted_on: DateTime (Date when the rate was inserted)
    Args:
        query: The query to perform. This should be correct SQL.
    Returns:
        A Markdown table of the result rows, "Aucun résultat trouvé." when the
        query returns no rows, or "Error executing query: ..." on failure.
    """
    try:
        with engine.connect() as con:
            result = con.execute(text(query))
            # Take column names from the result itself: building dicts per
            # row would silently merge duplicate names such as `SELECT a.x, b.x`.
            headers = list(result.keys())
            rows = result.fetchall()
    except Exception as e:
        # Tool boundary: report the failure to the caller (the LLM) as text
        # rather than raising.
        return f"Error executing query: {str(e)}"
    if not rows:
        return "Aucun résultat trouvé."
    return _format_markdown_table(headers, rows)
def get_schema() -> str:
    """
    Describe the columns of the `freights` table.

    Returns:
        A newline-delimited listing: a leading blank line, the table name,
        a "Columns:" header, then one "- name: Type (meaning)" line per column.
    """
    column_lines = [
        "- departure: DateTime (Date and time of departure)",
        "- origin_port_locode: String (Origin port code)",
        "- origin_port_name: String (Name of the origin port)",
        "- destination_port: String (Destination port code)",
        "- destination_port_name: String (Name of the destination port)",
        "- dv20rate: Float (Rate for 20ft container in USD)",
        "- dv40rate: Float (Rate for 40ft container in USD)",
        "- currency: String (Currency of the rates)",
        "- inserted_on: DateTime (Date when the rate was inserted)",
    ]
    body = "\n".join(["Table: freights", "Columns:", *column_lines])
    return "\n" + body + "\n"
def get_csv_as_dataframe() -> str:
    """
    Export the entire `freights` table as CSV text.

    Despite the name, no DataFrame is returned: the table is read into a
    pandas DataFrame internally and serialized to a CSV string, without the
    DataFrame index column.

    Returns:
        The CSV text, including a header row.
    """
    frame = pd.read_sql_table("freights", con=engine)
    csv_text = frame.to_csv(index=False)
    return csv_text
def run_agent(question: str) -> str:
    """
    Answer a natural-language question about the freight data.
    This ReAct agent writes and runs code against the freight tools (SQL
    queries, table schema, CSV export) to build its answer.
    Args:
        question: The question to run the agent with.
    Returns:
        The agent's final answer as text.
    """
    # Don't launch a multi-step agent run when there is nothing to ask.
    if not question or not question.strip():
        return "Please enter a question about the freight data."
    # agent.run() returns the agent's final answer, which is not necessarily a
    # str; coerce it so the declared return type holds for the UI/MCP output.
    return str(agent.run(question, max_steps=5))
if __name__ == "__main__":
    # Use EXAMPLE_QUERIES when it is defined at module level, otherwise show no
    # examples (globals().get returns None for a missing name).
    example_queries = globals().get("EXAMPLE_QUERIES")
    demo = gr.Interface(
        fn=run_agent,
        inputs=gr.Textbox(lines=7, label="Question"),
        outputs=gr.Textbox(),
        title="Freight Agent MCP",
        description="Ask a question about the freight data in natural language",
        examples=example_queries,
    )
    # mcp_server=True also exposes run_agent as an MCP tool next to the web UI.
    demo.launch(mcp_server=True)