catninja123 commited on
Commit
7f1fe1b
·
verified ·
1 Parent(s): a608a65

Add feeder_school_info + summer_program_recommendation tools (v1.1)

Browse files
Files changed (1) hide show
  1. agent/agentic_engine.py +214 -0
agent/agentic_engine.py CHANGED
@@ -248,6 +248,53 @@ TOOLS = [
248
  },
249
  },
250
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  ]
252
 
253
  # ─── 工具执行层 ────────────────────────────────────────────────────────────────
@@ -275,6 +322,10 @@ def _execute_tool(tool_name: str, tool_input: Dict) -> str:
275
  return _tool_calibrate_school_list(**tool_input)
276
  elif tool_name == "get_special_case_patterns":
277
  return _tool_get_special_case_patterns(**tool_input)
 
 
 
 
278
  else:
279
  return json.dumps({"error": f"Unknown tool: {tool_name}"})
280
  except Exception as e:
@@ -627,6 +678,169 @@ def _tool_get_special_case_patterns(pattern_type: str, major: str = "") -> str:
627
  }, ensure_ascii=False)
628
 
629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  # ─── SSE 工具 ───────��──────────────────────────────────────────────────────────
631
 
632
  def _sse(event: str, data: Any) -> str:
 
248
  },
249
  },
250
  },
251
+ {
252
+ "type": "function",
253
+ "function": {
254
+ "name": "get_feeder_school_info",
255
+ "description": (
256
+ "查询某所大学的大陆高中 feeder school 分布(哪些高中送了多少人),"
257
+ "或查询某所高中的录取档案(该高中历史上送往哪些大学、T10/T15 率)。"
258
+ "数据来源:meiben 27000+ 条大陆录取记录。"
259
+ ),
260
+ "parameters": {
261
+ "type": "object",
262
+ "properties": {
263
+ "school_name": {
264
+ "type": "string",
265
+ "description": "大学英文名(如 'Stanford University')",
266
+ },
267
+ "hs_name": {
268
+ "type": "string",
269
+ "description": "高中中文名(可选),若提供则返回该高中的录取档案",
270
+ },
271
+ },
272
+ "required": ["school_name"],
273
+ },
274
+ },
275
+ },
276
+ {
277
+ "type": "function",
278
+ "function": {
279
+ "name": "get_summer_program_recommendation",
280
+ "description": (
281
+ "根据学生的专业方向和目标档次,推荐适合的夏校和竞赛项目。"
282
+ "数据来源:meiben 录取数据中提取的夏校/竞赛给力指数(1-10分)。"
283
+ ),
284
+ "parameters": {
285
+ "type": "object",
286
+ "properties": {
287
+ "major": {"type": "string", "description": "专业方向"},
288
+ "sat": {"type": "number", "description": "SAT 分数(可选)"},
289
+ "target_tier": {
290
+ "type": "string",
291
+ "description": "目标档次:T10 / T15 / T20",
292
+ },
293
+ },
294
+ "required": ["major"],
295
+ },
296
+ },
297
+ },
298
  ]
299
 
300
  # ─── 工具执行层 ────────────────────────────────────────────────────────────────
 
322
  return _tool_calibrate_school_list(**tool_input)
323
  elif tool_name == "get_special_case_patterns":
324
  return _tool_get_special_case_patterns(**tool_input)
325
+ elif tool_name == "get_feeder_school_info":
326
+ return _tool_get_feeder_school_info(**tool_input)
327
+ elif tool_name == "get_summer_program_recommendation":
328
+ return _tool_get_summer_program_recommendation(**tool_input)
329
  else:
330
  return json.dumps({"error": f"Unknown tool: {tool_name}"})
331
  except Exception as e:
 
678
  }, ensure_ascii=False)
679
 
680
 
681
+ def _tool_get_feeder_school_info(school_name: str, hs_name: str = None) -> str:
682
+ """
683
+ 查询某所大学的大陆高中 feeder school 分布,或某所高中的录取档案。
684
+ school_name: 大学英文名(如 'Stanford University')
685
+ hs_name: 可选,高中中文名(如 '北京师范大学附属实验中学国际部'),若提供则返回该高中的录取档案
686
+ """
687
+ feeder_file = DATA_DIR / "meiben_feeder_kb.json"
688
+ if not feeder_file.exists():
689
+ return json.dumps({"error": "meiben_feeder_kb.json not found"})
690
+
691
+ with open(feeder_file) as f:
692
+ kb = json.load(f)
693
+
694
+ # 查询高中档案
695
+ if hs_name:
696
+ # 模糊匹配高中名
697
+ hs_profiles = kb.get("high_school_profiles", {})
698
+ matched = None
699
+ for name, profile in hs_profiles.items():
700
+ if hs_name in name or name in hs_name:
701
+ matched = (name, profile)
702
+ break
703
+ if not matched:
704
+ # 尝试更宽松的匹配
705
+ hs_lower = hs_name.lower().replace(" ", "")
706
+ for name, profile in hs_profiles.items():
707
+ if hs_lower in name.lower().replace(" ", ""):
708
+ matched = (name, profile)
709
+ break
710
+ if matched:
711
+ name, profile = matched
712
+ return json.dumps({
713
+ "hs_name": name,
714
+ "province": profile.get("province"),
715
+ "total_offers": profile.get("total_offers"),
716
+ "t10_ivy_offers": profile.get("t10_ivy_offers"),
717
+ "t15_offers": profile.get("t15_offers"),
718
+ "t10_rate": profile.get("t10_rate"),
719
+ "t15_rate": profile.get("t15_rate"),
720
+ "top_universities": profile.get("top_universities", [])[:12],
721
+ "years_active": profile.get("years_active"),
722
+ }, ensure_ascii=False)
723
+ return json.dumps({"error": f"High school '{hs_name}' not found in database"})
724
+
725
+ # 查询大学的 feeder school 分布
726
+ feeder_kb = kb.get("feeder_school_kb", {})
727
+ # 模糊匹配大学名
728
+ matched_uni = None
729
+ for uni, data in feeder_kb.items():
730
+ if school_name.lower() in uni.lower() or uni.lower() in school_name.lower():
731
+ matched_uni = (uni, data)
732
+ break
733
+ if not matched_uni:
734
+ # 尝试缩写匹配
735
+ abbrev_map = {
736
+ "stanford": "Stanford University",
737
+ "harvard": "Harvard University",
738
+ "mit": "Massachusetts Institute of Technology",
739
+ "yale": "Yale University",
740
+ "princeton": "Princeton University",
741
+ "columbia": "Columbia University",
742
+ "upenn": "University of Pennsylvania",
743
+ "penn": "University of Pennsylvania",
744
+ "duke": "Duke University",
745
+ "dartmouth": "Dartmouth College",
746
+ "brown": "Brown University",
747
+ "cornell": "Cornell University",
748
+ "washu": "Washington University in St.Louis",
749
+ "wustl": "Washington University in St.Louis",
750
+ }
751
+ for abbrev, full_name in abbrev_map.items():
752
+ if abbrev in school_name.lower():
753
+ for uni, data in feeder_kb.items():
754
+ if full_name.lower() in uni.lower():
755
+ matched_uni = (uni, data)
756
+ break
757
+ if matched_uni:
758
+ break
759
+
760
+ if not matched_uni:
761
+ available = list(feeder_kb.keys())[:20]
762
+ return json.dumps({"error": f"University '{school_name}' not found", "available_universities": available})
763
+
764
+ uni_name, data = matched_uni
765
+ # 同时返回省市偏好
766
+ province_pref = kb.get("province_preference", {}).get(uni_name, {})
767
+
768
+ return json.dumps({
769
+ "university": uni_name,
770
+ "total_mainland_offers": data.get("total_offers"),
771
+ "top_feeder_schools": data.get("top_feeder_schools", [])[:10],
772
+ "province_distribution": data.get("province_distribution", {}),
773
+ "year_trend": data.get("year_trend", {}),
774
+ "province_preference": province_pref,
775
+ "insight": f"{uni_name} 在大陆的主要来源高中集中在 {', '.join(list(data.get('province_distribution', {}).keys())[:3])} 等省市。"
776
+ }, ensure_ascii=False)
777
+
778
+
779
+ def _tool_get_summer_program_recommendation(major: str, sat: float = None,
780
+ hs_type: str = "国际高中",
781
+ target_tier: str = "T15") -> str:
782
+ """
783
+ 根据学生背景推荐适合的夏校和竞赛项目。
784
+ major: 专业方向
785
+ sat: SAT 分数(可选)
786
+ hs_type: 高中类型
787
+ target_tier: 目标档次 T10/T15/T20
788
+ """
789
+ feeder_file = DATA_DIR / "meiben_feeder_kb.json"
790
+ scores_file = DATA_DIR / "summer_program_scores.json"
791
+
792
+ programs = []
793
+
794
+ # 从 meiben_feeder_kb 的 summer_program_db 获取项目列表
795
+ if feeder_file.exists():
796
+ with open(feeder_file) as f:
797
+ kb = json.load(f)
798
+ all_programs = kb.get("summer_program_db", [])
799
+
800
+ # 按专业方向过滤
801
+ major_lower = major.lower()
802
+ major_keywords = {
803
+ "cs": ["computer", "coding", "programming", "software", "ai", "data"],
804
+ "bio": ["biology", "biomedical", "life science", "medicine", "health"],
805
+ "physics": ["physics", "astronomy", "astrophysics"],
806
+ "econ": ["economics", "business", "finance", "policy"],
807
+ "env": ["environment", "sustainability", "climate", "earth"],
808
+ "humanities": ["history", "literature", "philosophy", "writing", "language"],
809
+ "math": ["math", "statistics", "quantitative"],
810
+ "engineering": ["engineering", "mechanical", "electrical", "civil"],
811
+ }
812
+
813
+ # 确定专业类别
814
+ major_cat = "general"
815
+ for cat, keywords in major_keywords.items():
816
+ if any(kw in major_lower for kw in keywords):
817
+ major_cat = cat
818
+ break
819
+
820
+ # 按给力指数排序,过滤出高质量项目
821
+ min_geili = 7.0 if target_tier == "T10" else 6.0 if target_tier == "T15" else 5.0
822
+ filtered = [p for p in all_programs if p.get("avg_geili_score", 0) >= min_geili]
823
+ filtered.sort(key=lambda x: x.get("avg_geili_score", 0), reverse=True)
824
+
825
+ # 分类型返回
826
+ summer_progs = [p for p in filtered if p.get("type") == "summer_program"][:8]
827
+ competitions = [p for p in filtered if p.get("type") == "competition"][:6]
828
+
829
+ programs = {
830
+ "summer_programs": summer_progs,
831
+ "competitions": competitions,
832
+ "total_filtered": len(filtered),
833
+ "filter_criteria": f"给力指数 >= {min_geili}({target_tier} 目标)",
834
+ }
835
+
836
+ return json.dumps({
837
+ "major": major,
838
+ "target_tier": target_tier,
839
+ "recommendations": programs,
840
+ "note": "给力指数 10=顶级(RSI/ISEF),8-9=强力,6-7=良好,5以下=一般",
841
+ }, ensure_ascii=False)
842
+
843
+
844
  # ─── SSE 工具 ───────��──────────────────────────────────────────────────────────
845
 
846
  def _sse(event: str, data: Any) -> str: