"""Demo: filter a database schema down to the items relevant to one question.

Builds a single eval-mode sample (question + schema), loads the trained
schema-item classifier, and runs ``filter_func`` to keep only the top-k
tables and top-k columns per table.
"""
from schema_filter import filter_func, SchemaItemClassifierInference


def _table(table_name, column_names):
    """Build one schema item with empty table and column comments.

    Args:
        table_name: Name of the table.
        column_names: Column names of the table, in order.

    Returns:
        A dict with ``table_name``, ``table_comment``, ``column_names`` and
        ``column_comments`` keys, where every comment is the empty string.
    """
    column_names = list(column_names)
    return {
        "table_name": table_name,
        "table_comment": "",
        "column_names": column_names,
        # One (empty) comment per column; derived from column_names so the
        # two parallel lists can never get out of sync.
        "column_comments": [""] * len(column_names),
    }


# In eval mode the gold SQL does not need to be provided, so "sql" is empty.
data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "",
    "schema": {
        "schema_items": [
            _table("lists", [
                "user_id",
                "list_id",
                "list_title",
                "list_movie_number",
                "list_update_timestamp_utc",
                "list_creation_timestamp_utc",
                "list_followers",
                "list_url",
                "list_comments",
                "list_description",
                "list_cover_image_url",
                "list_first_image_url",
                "list_second_image_url",
                "list_third_image_url",
            ]),
            _table("movies", [
                "movie_id",
                "movie_title",
                "movie_release_year",
                "movie_url",
                "movie_title_language",
                "movie_popularity",
                "movie_image_url",
                "director_id",
                "director_name",
                "director_url",
            ]),
            _table("ratings_users", [
                "user_id",
                "rating_date_utc",
                "user_trialist",
                "user_subscriber",
                "user_avatar_image_url",
                "user_cover_image_url",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
            _table("lists_users", [
                "user_id",
                "list_id",
                "list_update_date_utc",
                "list_creation_date_utc",
                "user_trialist",
                "user_subscriber",
                "user_avatar_image_url",
                "user_cover_image_url",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
            _table("ratings", [
                "movie_id",
                "rating_id",
                "rating_url",
                "rating_score",
                "rating_timestamp_utc",
                "critic",
                "critic_likes",
                "critic_comments",
                "user_id",
                "user_trialist",
                "user_subscriber",
                "user_eligible_for_trial",
                "user_has_payment_method",
            ]),
        ]
    },
}
dataset = [data]

# Keep at most 7 tables from the database.
num_top_k_tables = 7
# For each retained table keep at most 10 columns, so the prompt contains
# at most 7 * 10 = 70 columns.
num_top_k_columns = 10


def main():
    """Load the classifier, filter ``dataset`` and print the result.

    Returns:
        Whatever ``filter_func`` returns for the eval dataset.
    """
    # Load the trained schema-item classifier model.
    sic = SchemaItemClassifierInference("sic_merged")
    # For test data, the trained classifier scores tables and columns
    # against the user question; filter_func keeps the top-k of each.
    filtered = filter_func(
        dataset=dataset,
        dataset_type="eval",
        sic=sic,
        num_top_k_tables=num_top_k_tables,
        num_top_k_columns=num_top_k_columns,
    )
    print(filtered)
    return filtered


# Guard the expensive model load / inference so importing this module for
# its sample data does not trigger them.
if __name__ == "__main__":
    main()