stefanches7 committed on
Commit
8d0fd38
·
1 Parent(s): 6fc7119

use dspy to abstract from model provider

Browse files
Files changed (2) hide show
  1. agent.py +12 -8
  2. cli.py +3 -2
agent.py CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, Dict, Any
2
  import os
3
  from dotenv import load_dotenv
4
 
 
5
  from openai import OpenAI
6
 
7
  import prompts as prompts_mod
@@ -10,11 +11,15 @@ import prompts as prompts_mod
10
  class BIDSifierAgent:
11
  """Wrapper around OpenAI chat API for step-wise BIDSification."""
12
 
13
- def __init__(self, *, model: Optional[str] = None, temperature: float = 0.2):
14
  load_dotenv()
15
- if not os.getenv("OPENAI_API_KEY"):
16
- raise RuntimeError("OPENAI_API_KEY not set in environment.")
17
- self.client = OpenAI()
 
 
 
 
18
  self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
19
  self.temperature = temperature
20
 
@@ -63,15 +68,14 @@ class BIDSifierAgent:
63
  def run_step(self, step: str, context: Dict[str, Any]) -> str:
64
  system_msg = prompts_mod.system_prompt()
65
  user_msg = self._build_user_prompt(step, context)
66
- resp = self.client.chat.completions.create(
67
- model=self.model,
68
- temperature=self.temperature,
69
  messages=[
70
  {"role": "system", "content": system_msg},
71
  {"role": "user", "content": user_msg},
72
  ],
 
73
  )
74
- return resp.choices[0].message.content
75
 
76
 
77
  __all__ = ["BIDSifierAgent"]
 
2
  import os
3
  from dotenv import load_dotenv
4
 
5
+ import dspy
6
  from openai import OpenAI
7
 
8
  import prompts as prompts_mod
 
11
  class BIDSifierAgent:
12
  """Wrapper around OpenAI chat API for step-wise BIDSification."""
13
 
14
def __init__(self, *, provider: Optional[str] = None, model: Optional[str] = None, temperature: float = 0.2):
    """Configure a dspy language model for the chosen provider/model.

    Args:
        provider: LM provider identifier (e.g. ``"openai"``). Defaults to
            ``"openai"`` when not given, matching the CLI default.
        model: Model name. Falls back to the ``BIDSIFIER_MODEL`` environment
            variable, then ``"gpt-4o-mini"``.
        temperature: Sampling temperature applied to each call.

    Raises:
        RuntimeError: If provider is ``"openai"`` and ``OPENAI_API_KEY``
            is not set in the environment.
    """
    load_dotenv()
    # Resolve defaults *before* building the LM identifier string;
    # otherwise model=None would produce the bogus name "openai/None".
    provider = provider or "openai"
    self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
    self.temperature = temperature

    if provider == "openai":
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            # Fail fast with a clear message instead of an opaque
            # provider-side authentication error on the first call.
            raise RuntimeError("OPENAI_API_KEY not set in environment.")
        lm = dspy.LM(f"{provider}/{self.model}", api_key=api_key)
    else:
        # Non-OpenAI providers (e.g. local backends) may not require a
        # key; dspy/litellm handles provider-specific credentials.
        lm = dspy.LM(f"{provider}/{self.model}", api_key="")
    dspy.configure(lm=lm)
    self.llm = lm
25
 
 
68
def run_step(self, step: str, context: Dict[str, Any]) -> str:
    """Run a single BIDSification step and return the model's reply text."""
    conversation = [
        {"role": "system", "content": prompts_mod.system_prompt()},
        {"role": "user", "content": self._build_user_prompt(step, context)},
    ]
    completions = self.llm(
        messages=conversation,
        temperature=self.temperature,
    )
    # dspy.LM returns a list of completion strings; use the first one.
    return completions[0]
79
 
80
 
81
  __all__ = ["BIDSifierAgent"]
cli.py CHANGED
@@ -65,7 +65,8 @@ def main(argv: Optional[List[str]] = None) -> int:
65
  parser.add_argument("--readme", dest="readme_path", help="Path to dataset README file", required=False)
66
  parser.add_argument("--publication", dest="publication_path", help="Path to a publication/notes file", required=False)
67
  parser.add_argument("--output-root", dest="output_root", help="Target BIDS root directory", required=True)
68
- parser.add_argument("--model", dest="model", help="OpenAI model name", default=os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini"))
 
69
  # Execution is intentionally disabled; we only display commands.
70
  # Keeping --dry-run for backward compatibility (no effect other than display).
71
  parser.add_argument("--dry-run", dest="dry_run", help="Display-only (default behavior)", action="store_true")
@@ -93,7 +94,7 @@ def main(argv: Optional[List[str]] = None) -> int:
93
  if args.publication_path:
94
  command_env["PUBLICATION_PATH"] = os.path.abspath(args.publication_path)
95
 
96
- agent = BIDSifierAgent(model=args.model)
97
 
98
  short_divider("Step 1: Understand dataset")
99
  summary = agent.run_step("summary", context)
 
65
  parser.add_argument("--readme", dest="readme_path", help="Path to dataset README file", required=False)
66
  parser.add_argument("--publication", dest="publication_path", help="Path to a publication/notes file", required=False)
67
  parser.add_argument("--output-root", dest="output_root", help="Target BIDS root directory", required=True)
68
+ parser.add_argument("--provider", dest="provider", help="Provider name or identifier, default OpenAI", required=False, default="openai")
69
+ parser.add_argument("--model", dest="model", help="Model name to use", default=os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini"))
70
  # Execution is intentionally disabled; we only display commands.
71
  # Keeping --dry-run for backward compatibility (no effect other than display).
72
  parser.add_argument("--dry-run", dest="dry_run", help="Display-only (default behavior)", action="store_true")
 
94
  if args.publication_path:
95
  command_env["PUBLICATION_PATH"] = os.path.abspath(args.publication_path)
96
 
97
+ agent = BIDSifierAgent(provider=args.provider, model=args.model)
98
 
99
  short_divider("Step 1: Understand dataset")
100
  summary = agent.run_step("summary", context)