Laramie2 commited on
Commit
6b075de
·
verified ·
1 Parent(s): 89d15ac

Update src/paper2DAG.py

Browse files
Files changed (1) hide show
  1. src/paper2DAG.py +39 -31
src/paper2DAG.py CHANGED
@@ -16,13 +16,15 @@ def clean_paper(markdown_path, clean_prompt, model, config):
16
  删除 Abstract / Related Work / Appendix / References 等部分,
17
  保留标题、作者、Introduction、Methods、Experiments、Conclusion。
18
  """
19
- # === 可选:使用 config 中的 api_base_url 覆盖默认值 ===
20
- if config['api_base_url'] is not None and config['api_base_url'].strip():
21
- genai.config.api_base = config['api_base_url'].strip().rstrip("/") + "/v1"
22
- print(f"🔧 Using custom API base: {genai.config.api_base}")
23
- # === 初始化 Client ===
 
24
  client = genai.Client(
25
- api_key=config['api_keys']['gemini_api_key']
 
26
  )
27
 
28
  # === 读取 markdown 文件 ===
@@ -94,11 +96,6 @@ def split_paper(
94
  使用 Gemini 拆分论文,并将所有拆分后的 markdown 保存在:
95
  <parent_of_auto>/section_split_output/
96
  """
97
- # === 可选:使用 config 中的 api_base_url 覆盖默认值 ===
98
- if config['api_base_url'] is not None and config['api_base_url'].strip():
99
- genai.config.api_base = config['api_base_url'].strip().rstrip("/") + "/v1"
100
- print(f"🔧 Using custom API base: {genai.config.api_base}")
101
-
102
  # 1️⃣ 输入文件所在的 auto 文件夹
103
  auto_dir = os.path.dirname(os.path.abspath(cleaned_md_path))
104
 
@@ -114,8 +111,15 @@ def split_paper(
114
  markdown_text = f.read()
115
 
116
  # === 2. 初始化 Gemini Client ===
 
 
 
 
 
 
117
  client = genai.Client(
118
- api_key=config['api_keys']['gemini_api_key']
 
119
  )
120
 
121
  # === 提取一级 section 信息供参考 (假设 SECTION_RE 已在外部定义) ===
@@ -208,14 +212,16 @@ def initialize_dag(markdown_path, initialize_dag_prompt, model, config=None):
208
  with open(markdown_path, "r", encoding="utf-8") as f:
209
  md_text = f.read()
210
 
211
- # === 可选:使用 config 中的 api_base_url 覆盖默认值 ===
212
- if config['api_base_url'] is not None and config['api_base_url'].strip():
213
- genai.config.api_base = config['api_base_url'].strip().rstrip("/") + "/v1"
214
- print(f"🔧 Using custom API base: {genai.config.api_base}")
215
-
216
  # --- Gemini Client Init ---
 
 
 
 
 
 
217
  client = genai.Client(
218
- api_key=config['api_keys']['gemini_api_key']
 
219
  )
220
 
221
  # --- Gemini Call ---
@@ -314,14 +320,15 @@ def extract_and_generate_visual_dag(
314
  normalized_refs = [f"![]({m})" for m in relative_imgs]
315
 
316
  # === 3. 发送给 Gemini ===
317
- # === 可选:使用 config 中的 api_base_url 覆盖默认值 ===
318
- if config['api_base_url'] is not None and config['api_base_url'].strip():
319
- genai.config.api_base = config['api_base_url'].strip().rstrip("/") + "/v1"
320
- print(f"🔧 Using custom API base: {genai.config.api_base}")
321
-
322
  # 初始化 Client
323
  client = genai.Client(
324
- api_key=config['api_keys']['gemini_api_key']
 
325
  )
326
 
327
  gpt_input = prompt_for_gpt + "\n\n" + \
@@ -544,15 +551,16 @@ def build_section_dags(
544
  FALLBACK_STRIP_BACKSLASH_ONLY_IN_CONTENT = True
545
  MAX_RETRIES_ON_FAIL = 2
546
 
547
- # === Init Client (Gemini) ===
548
- # === 可选:使用 config 中的 api_base_url 覆盖默认值 ===
549
- if config['api_base_url'] is not None and config['api_base_url'].strip():
550
- genai.config.api_base = config['api_base_url'].strip().rstrip("/") + "/v1"
551
- print(f"🔧 Using custom API base: {genai.config.api_base}")
552
-
553
  # 使用 config 中的 key
554
  client = genai.Client(
555
- api_key=config['api_keys']['gemini_api_key']
 
556
  )
557
 
558
  def build_full_prompt(base_prompt: str, section_name: str, md_text: str) -> str:
 
16
  删除 Abstract / Related Work / Appendix / References 等部分,
17
  保留标题、作者、Introduction、Methods、Experiments、Conclusion。
18
  """
19
+ raw_url = config.get('api_base_url', '').strip().rstrip("/")
20
+ if raw_url.endswith("/v1"):
21
+ base_url = raw_url[:-3].rstrip("/") # 去掉最后的 /v1
22
+ else:
23
+ base_url = raw_url
24
+
25
  client = genai.Client(
26
+ api_key=config['api_keys']['gemini_api_key'],
27
+ http_options={'base_url': base_url} if base_url else None
28
  )
29
 
30
  # === 读取 markdown 文件 ===
 
96
  使用 Gemini 拆分论文,并将所有拆分后的 markdown 保存在:
97
  <parent_of_auto>/section_split_output/
98
  """
 
 
 
 
 
99
  # 1️⃣ 输入文件所在的 auto 文件夹
100
  auto_dir = os.path.dirname(os.path.abspath(cleaned_md_path))
101
 
 
111
  markdown_text = f.read()
112
 
113
  # === 2. 初始化 Gemini Client ===
114
+ raw_url = config.get('api_base_url', '').strip().rstrip("/")
115
+ if raw_url.endswith("/v1"):
116
+ base_url = raw_url[:-3].rstrip("/") # 去掉最后的 /v1
117
+ else:
118
+ base_url = raw_url
119
+
120
  client = genai.Client(
121
+ api_key=config['api_keys']['gemini_api_key'],
122
+ http_options={'base_url': base_url} if base_url else None
123
  )
124
 
125
  # === 提取一级 section 信息供参考 (假设 SECTION_RE 已在外部定义) ===
 
212
  with open(markdown_path, "r", encoding="utf-8") as f:
213
  md_text = f.read()
214
 
 
 
 
 
 
215
  # --- Gemini Client Init ---
216
+ raw_url = config.get('api_base_url', '').strip().rstrip("/")
217
+ if raw_url.endswith("/v1"):
218
+ base_url = raw_url[:-3].rstrip("/") # 去掉最后的 /v1
219
+ else:
220
+ base_url = raw_url
221
+
222
  client = genai.Client(
223
+ api_key=config['api_keys']['gemini_api_key'],
224
+ http_options={'base_url': base_url} if base_url else None
225
  )
226
 
227
  # --- Gemini Call ---
 
320
  normalized_refs = [f"![]({m})" for m in relative_imgs]
321
 
322
  # === 3. 发送给 Gemini ===
323
+ raw_url = config.get('api_base_url', '').strip().rstrip("/")
324
+ if raw_url.endswith("/v1"):
325
+ base_url = raw_url[:-3].rstrip("/") # 去掉最后的 /v1
326
+ else:
327
+ base_url = raw_url
328
  # 初始化 Client
329
  client = genai.Client(
330
+ api_key=config['api_keys']['gemini_api_key'],
331
+ http_options={'base_url': base_url} if base_url else None
332
  )
333
 
334
  gpt_input = prompt_for_gpt + "\n\n" + \
 
551
  FALLBACK_STRIP_BACKSLASH_ONLY_IN_CONTENT = True
552
  MAX_RETRIES_ON_FAIL = 2
553
 
554
+ # === Init Client (Gemini) ===
555
+ raw_url = config.get('api_base_url', '').strip().rstrip("/")
556
+ if raw_url.endswith("/v1"):
557
+ base_url = raw_url[:-3].rstrip("/") # 去掉最后的 /v1
558
+ else:
559
+ base_url = raw_url
560
  # 使用 config 中的 key
561
  client = genai.Client(
562
+ api_key=config['api_keys']['gemini_api_key'],
563
+ http_options={'base_url': base_url} if base_url else None
564
  )
565
 
566
  def build_full_prompt(base_prompt: str, section_name: str, md_text: str) -> str: