Yasu777 commited on
Commit
349032e
·
verified ·
1 Parent(s): f993b67

Update article_generator.py

Browse files
Files changed (1) hide show
  1. article_generator.py +14 -23
article_generator.py CHANGED
@@ -28,14 +28,15 @@ class GoogleSearchTool:
28
 
29
  # Tavily APIのカスタムツールを定義
30
  class EnhancedTavilySearchTool:
31
- def search(self, query):
32
- if len(query) < 5:
33
- query += " details" # クエリを拡張して必要な文字数にする
 
34
 
35
  params = {
36
  'api_key': tavily_api_key,
37
- 'query': query,
38
- 'max_results': 5,
39
  'detail_level': 'high',
40
  'search_depth': 'advanced'
41
  }
@@ -74,8 +75,8 @@ def clear_state():
74
  research_results = []
75
  return "状態がクリアされました"
76
 
77
- # 見出しを処理する関数(キャッシュされたレスポンスを使用)
78
- def process_heading(agent, h2_text, h3_for_this_h2, cached_responses):
79
  query = f"{h2_text} {' '.join(h3_for_this_h2)}"
80
  if query in cached_responses:
81
  return (query, cached_responses[query])
@@ -85,26 +86,15 @@ def process_heading(agent, h2_text, h3_for_this_h2, cached_responses):
85
  # 初期データをTavily検索で収集する関数
86
  def perform_initial_tavily_search(h2_texts, h3_texts):
87
  tavily_search_tool = EnhancedTavilySearchTool()
88
- cached_responses = {}
89
  queries = []
90
-
91
  for h2_text in h2_texts:
92
  h3_for_this_h2 = [h3 for h3 in h3_texts if h3.startswith(f"{h2_texts.index(h2_text)+1}-")]
93
  query = f"{h2_text} {' '.join(h3_for_this_h2)}"
94
  queries.append(query)
95
-
96
- with ThreadPoolExecutor(max_workers=10) as executor:
97
- futures = {executor.submit(tavily_search_tool.search, query): query for query in queries}
98
- for future in as_completed(futures):
99
- query = futures[future]
100
- try:
101
- response = future.result()
102
- cached_responses[query] = response
103
- except Exception as e:
104
- print(f"Error occurred during Tavily search for query '{query}': {str(e)}")
105
- cached_responses[query] = str(e)
106
-
107
- return cached_responses
108
 
109
  # キャッシュされたTavilyデータを保存する関数
110
  def save_preloaded_tavily_data(data):
@@ -159,7 +149,7 @@ def generate_article(editable_output2):
159
  futures = []
160
  for h2_text in h2_texts:
161
  h3_for_this_h2 = [h3 for h3 in h3_texts if h3.startswith(f"{h2_texts.index(h2_text)+1}-")]
162
- futures.append(executor.submit(process_heading, agent, h2_text, h3_for_this_h2, cached_responses))
163
 
164
  for future in as_completed(futures):
165
  purpose, response = future.result()
@@ -311,3 +301,4 @@ def continue_generate_article():
311
  os.remove("state.json")
312
 
313
  return final_result
 
 
28
 
29
  # Tavily APIのカスタムツールを定義
30
  class EnhancedTavilySearchTool:
31
+ def search(self, queries):
32
+ combined_query = " | ".join(queries) # クエリを結合して一つのリクエストで処理
33
+ if len(combined_query) < 5:
34
+ combined_query += " details"
35
 
36
  params = {
37
  'api_key': tavily_api_key,
38
+ 'query': combined_query,
39
+ 'max_results': 50, # 必要に応じて結果の数を調整
40
  'detail_level': 'high',
41
  'search_depth': 'advanced'
42
  }
 
75
  research_results = []
76
  return "状態がクリアされました"
77
 
78
+ # 見出しを処理する関数
79
+ def process_heading(agent, h2_text, h3_for_this_h2, executed_instructions, research_results, cached_responses):
80
  query = f"{h2_text} {' '.join(h3_for_this_h2)}"
81
  if query in cached_responses:
82
  return (query, cached_responses[query])
 
86
  # 初期データをTavily検索で収集する関数
87
  def perform_initial_tavily_search(h2_texts, h3_texts):
88
  tavily_search_tool = EnhancedTavilySearchTool()
 
89
  queries = []
90
+
91
  for h2_text in h2_texts:
92
  h3_for_this_h2 = [h3 for h3 in h3_texts if h3.startswith(f"{h2_texts.index(h2_text)+1}-")]
93
  query = f"{h2_text} {' '.join(h3_for_this_h2)}"
94
  queries.append(query)
95
+
96
+ response = tavily_search_tool.search(queries) # 一回のリクエストで全てのクエリを処理
97
+ return {query: response[i] for i, query in enumerate(queries)} # 結果を適切にマッピング
 
 
 
 
 
 
 
 
 
 
98
 
99
  # キャッシュされたTavilyデータを保存する関数
100
  def save_preloaded_tavily_data(data):
 
149
  futures = []
150
  for h2_text in h2_texts:
151
  h3_for_this_h2 = [h3 for h3 in h3_texts if h3.startswith(f"{h2_texts.index(h2_text)+1}-")]
152
+ futures.append(executor.submit(process_heading, agent, h2_text, h3_for_this_h2, executed_instructions, research_results, cached_responses))
153
 
154
  for future in as_completed(futures):
155
  purpose, response = future.result()
 
301
  os.remove("state.json")
302
 
303
  return final_result
304
+