Yasu777 committed on
Commit
6330542
·
verified ·
1 Parent(s): 94e14bd

Update article_generator.py

Browse files
Files changed (1) hide show
  1. article_generator.py +24 -3
article_generator.py CHANGED
@@ -47,6 +47,23 @@ class EnhancedTavilySearchTool:
47
  else:
48
  raise Exception(f"Failed to fetch data from Tavily API: {response.status_code}, {response.text}")
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # 重複を排除するヘルパー関数
51
  def remove_duplicates(text_list):
52
  seen = set()
@@ -227,10 +244,14 @@ def perform_initial_tavily_search(h2_texts, h3_texts):
227
 
228
  print("Performing Tavily search with queries:", queries)
229
  responses = tavily_search_tool.search(queries)
 
 
 
 
230
  response_dict = {}
231
  for i, query in enumerate(queries):
232
- if i < len(responses): # 応答リストの範囲内にあることを確認
233
- response_dict[query] = responses[i]
234
  else:
235
  response_dict[query] = "No response received"
236
 
@@ -440,4 +461,4 @@ def setup_gradio_interface():
440
  updated_content = display_content(latest_content.get(), format_choice)
441
  return updated_content
442
 
443
- format_selector.change(update_content, inputs=[format_selector], outputs=[content_display])
 
47
  else:
48
  raise Exception(f"Failed to fetch data from Tavily API: {response.status_code}, {response.text}")
49
 
50
# Domain-filtering helper.
def filter_responses_by_domain(responses, allowed_domains):
    """Keep only the responses whose URL host ends with an allowed domain suffix.

    Args:
        responses: Iterable of mapping-like search results; each result is
            looked up with ``.get('url')``.
        allowed_domains: Iterable of domain suffixes such as ``'.gov'`` or
            ``'.co.jp'``.

    Returns:
        A new list with the matching responses, in their original order.
        Responses with a missing, empty, non-string or unparsable URL are
        dropped instead of raising.
    """
    from urllib.parse import urlparse

    # str.endswith accepts a tuple, so build it once outside the loop.
    suffixes = tuple(domain.lower() for domain in allowed_domains)

    filtered_responses = []
    for response in responses:
        url = response.get('url')
        if not isinstance(url, str) or not url:
            # A missing URL used to crash with TypeError on ``domain in None``.
            continue
        try:
            host = urlparse(url).hostname
            if host is None:
                # Scheme-less URLs ("example.org/page") carry no netloc
                # unless they are prefixed with "//".
                host = urlparse('//' + url).hostname
        except ValueError:
            # Malformed URL (e.g. a broken IPv6 literal): treat as no match.
            continue
        # Compare against the host only, so a suffix that merely appears in
        # the path or query string does not pass the filter.
        if host and host.endswith(suffixes):
            filtered_responses.append(response)
    return filtered_responses
58
+
59
+ # List of allowed domains
60
+ allowed_domains = [
61
+ '.gov', # government-related sites
62
+ '.edu', # educational institution sites
63
+ '.org', # non-profit organization sites
64
+ '.co.jp' # Japanese company sites
65
+ ]
66
+
67
  # 重複を排除するヘルパー関数
68
  def remove_duplicates(text_list):
69
  seen = set()
 
244
 
245
  print("Performing Tavily search with queries:", queries)
246
  responses = tavily_search_tool.search(queries)
247
+
248
+ # フィルタリングを適用
249
+ filtered_responses = filter_responses_by_domain(responses, allowed_domains)
250
+
251
  response_dict = {}
252
  for i, query in enumerate(queries):
253
+ if i < len(filtered_responses): # 応答リストの範囲内にあることを確認
254
+ response_dict[query] = filtered_responses[i]
255
  else:
256
  response_dict[query] = "No response received"
257
 
 
461
  updated_content = display_content(latest_content.get(), format_choice)
462
  return updated_content
463
 
464
+ format_selector.change(update_content, inputs=[format_selector], outputs=[content_display])