TaoZewen committed on
Commit
e242d6a
·
1 Parent(s): d1fa61b

prompt fix

Browse files
Files changed (1) hide show
  1. gemini_agent.py +51 -2
gemini_agent.py CHANGED
@@ -294,6 +294,49 @@ def wikipedia_qa(question: str) -> str:
294
  out = qa_pipeline(question=question, context=page)
295
  return out["answer"]
296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  class GeminiAgent:
298
  def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
299
  # Suppress warnings
@@ -344,7 +387,13 @@ class GeminiAgent:
344
  name="wikipedia_qa",
345
  func=wikipedia_qa,
346
  description="给定问题 + 维基文章,直接抽取精确答案"
347
- )
 
 
 
 
 
 
348
  ]
349
 
350
  # Setup memory
@@ -589,7 +638,7 @@ Observation: the result of the action
589
  2. Action: wikipedia_qa
590
  Action Input: <original user question>
591
  Observation: <extracted fact or number>
592
-
593
  When you have the final answer, you MUST use exactly one line with this format (no extra text, no lists):
594
 
595
  Thought: Do I need to use a tool? No
 
294
  out = qa_pipeline(question=question, context=page)
295
  return out["answer"]
296
 
297
+ import re
298
+
299
def count_studio_albums(query: str) -> str:
    """Count an artist's studio albums released within an inclusive year range.

    The artist and the year range are parsed from ``query``, the artist's
    Wikipedia page is fetched, and bullet lines of the form
    ``* Album Name (YYYY)`` inside the "Studio albums" section are counted.

    Example query:
        "How many studio albums were published by Mercedes Sosa between
        2000 and 2009 (included)?"

    Args:
        query: Natural-language question containing
            "studio albums were published by <artist> between YYYY and YYYY"
            (matched case-insensitively).

    Returns:
        The count as a decimal string; ``""`` when the query does not match
        the expected phrasing; ``"0"`` when the page has no "Studio albums"
        heading; or ``"Error fetching wiki page: ..."`` when the page or its
        content cannot be retrieved.
    """
    # 1) Extract the artist name and the year range from the question.
    m = re.search(
        r"studio albums were published by (.+?) between (\d{4}) and (\d{4})",
        query, flags=re.IGNORECASE
    )
    if not m:
        return ""  # query does not follow the expected phrasing
    artist = m.group(1).strip()
    start, end = int(m.group(2)), int(m.group(3))

    # 2) Fetch the full article text. Reading the content is kept inside the
    #    try so that a failure there is also reported as an error string
    #    rather than escaping to the caller.
    try:
        text = wikipedia.page(artist).content
    except Exception as e:
        return f"Error fetching wiki page: {e}"

    # 3) Locate the 'Studio albums' heading (also when it is the first line).
    heading = re.search(
        r"^(={2,})\s*Studio albums\s*={2,}", text, flags=re.MULTILINE
    )
    if heading is None:
        return "0"
    level = len(heading.group(1))
    rest = text[heading.end():]

    # The section ends at the next heading of the same or a higher level;
    # deeper sub-headings stay inside it. Without this bound, bullets from
    # later sections (live albums, compilations, ...) would be counted too.
    next_heading = re.search(rf"^={{2,{level}}}[^=]", rest, flags=re.MULTILINE)
    section = rest[:next_heading.start()] if next_heading else rest

    # 4) Count bullet lines whose first "(YYYY)" falls inside the range.
    # NOTE(review): this assumes list items in the page text start with "*";
    # confirm against the actual text returned by the wikipedia package.
    year_pattern = re.compile(r"\((\d{4})\)")
    count = 0
    for line in section.splitlines():
        line = line.strip()
        if not line.startswith("*"):
            continue
        y_m = year_pattern.search(line)
        if y_m and start <= int(y_m.group(1)) <= end:
            count += 1
    return str(count)
339
+
340
  class GeminiAgent:
341
  def __init__(self, api_key: str, model_name: str = "gemini-2.0-flash"):
342
  # Suppress warnings
 
387
  name="wikipedia_qa",
388
  func=wikipedia_qa,
389
  description="给定问题 + 维基文章,直接抽取精确答案"
390
+ ),
391
+ Tool(
392
+ name="count_studio_albums",
393
+ func=count_studio_albums,
394
+ description=("统计某位艺术家在指定年份区间内发布的 Studio albums 数量,"
395
+ "query 必须包含 “studio albums … between YYYY and YYYY”。")
396
+ )
397
  ]
398
 
399
  # Setup memory
 
638
  2. Action: wikipedia_qa
639
  Action Input: <original user question>
640
  Observation: <extracted fact or number>
641
+
642
  When you have the final answer, you MUST use exactly one line with this format (no extra text, no lists):
643
 
644
  Thought: Do I need to use a tool? No