zuazo committed on
Commit
0f31aa7
·
1 Parent(s): 135f5f7

Add multi-provider LLM support (OpenAI, Anthropic, Gemini, Databricks, Ollama)

Browse files
Files changed (2) hide show
  1. agent.py +19 -10
  2. gradio-ui.py +117 -27
agent.py CHANGED
@@ -9,25 +9,34 @@ import prompts
9
  class BIDSifierAgent:
10
  """Wrapper around OpenAI chat API for step-wise BIDSification."""
11
 
12
- def __init__(self, *, provider: str, model: str, openai_api_key: Optional[str] = None, temperature: float = 0.2):
 
 
 
 
 
 
 
13
  load_dotenv()
14
 
15
- if provider=="openai":
16
- if model.startswith("gpt-5"): #reasoning model that requires special handling
17
- temperature = 1.0
18
- lm = dspy.LM(f"{provider}/{model}", api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), temperature = temperature, max_tokens = 40000)
19
- else:
20
- lm = dspy.LM(f"{provider}/{model}", api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), temperature = temperature, max_tokens = 10000)
21
  else:
22
- lm = dspy.LM(f"{provider}/{model}", api_key="", max_tokens=10000)
 
23
 
24
-
 
 
 
 
25
 
26
  dspy.configure(lm=lm)
27
  self.llm = lm
28
  self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
29
  self.temperature = temperature
30
-
31
  def _build_user_prompt(self, step: str, context: Dict[str, Any]) -> str:
32
  dataset_xml = context.get("dataset_xml")
33
  readme_text = context.get("readme_text")
 
9
  class BIDSifierAgent:
10
  """Wrapper around OpenAI chat API for step-wise BIDSification."""
11
 
12
+ def __init__(
13
+ self, *,
14
+ provider: str,
15
+ model: str,
16
+ api_key: Optional[str] = None,
17
+ api_base: Optional[str] = None,
18
+ temperature: Optional[float | None] = 0.2,
19
+ ):
20
  load_dotenv()
21
 
22
+ if model.startswith("gpt-5"): # reasoning model that requires special handling
23
+ temperature = 1.0
24
+ max_tokens = 40000
 
 
 
25
  else:
26
+ temperature = None
27
+ max_tokens = 10000
28
 
29
+ lm = dspy.LM(
30
+ f"{provider}/{model}",
31
+ api_key=api_key or os.getenv(f"{provider.upper()}_API_KEY"),
32
+ temperature=temperature, max_tokens=max_tokens
33
+ )
34
 
35
  dspy.configure(lm=lm)
36
  self.llm = lm
37
  self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
38
  self.temperature = temperature
39
+
40
  def _build_user_prompt(self, step: str, context: Dict[str, Any]) -> str:
41
  dataset_xml = context.get("dataset_xml")
42
  readme_text = context.get("readme_text")
gradio-ui.py CHANGED
@@ -36,6 +36,47 @@ STEP_LABELS = list(BIDSIFIER_STEPS.keys())
36
  NUM_STEPS = len(STEP_LABELS)
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # Helpers
40
 
41
  def split_shell_commands(text: str) -> List[str]:
@@ -121,8 +162,33 @@ def build_context(
121
 
122
  # Core callbacks
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def call_bidsifier_step(
125
- openai_api_key: str,
 
126
  dataset_xml: str,
127
  readme_text: str,
128
  publication_text: str,
@@ -137,6 +203,10 @@ def call_bidsifier_step(
137
 
138
  Parameters
139
  ----------
 
 
 
 
140
  dataset_xml : str
141
  Dataset XML content.
142
  readme_text : str
@@ -185,7 +255,9 @@ def call_bidsifier_step(
185
  step_id = BIDSIFIER_STEPS[step_label]
186
  context = build_context(dataset_xml, readme_text, publication_text, output_root)
187
 
188
- agent = BIDSifierAgent(provider=provider, model=model, openai_api_key=openai_api_key)
 
 
189
 
190
  # Decide whether to use the structured step prompt or a free-form query:
191
  if manual_prompt.strip():
@@ -204,7 +276,7 @@ def call_bidsifier_step(
204
  step_index = 0
205
 
206
  state = {
207
- "openai_api_key": openai_api_key,
208
  "dataset_xml": dataset_xml,
209
  "readme_text": readme_text,
210
  "publication_text": publication_text,
@@ -217,7 +289,9 @@ def call_bidsifier_step(
217
  "commands": commands,
218
  }
219
 
220
- return llm_output, commands_str, state, step_index
 
 
221
 
222
 
223
  def confirm_commands(
@@ -288,7 +362,7 @@ def confirm_commands(
288
 
289
  agent = BIDSifierAgent(
290
  provider=last_state.get("provider", "openai"),
291
- model=last_state.get("model", "gpt-4o-mini"),
292
  )
293
 
294
  llm_output = agent.run_step(next_id, context)
@@ -479,13 +553,6 @@ with gr.Blocks(
479
  """
480
  )
481
 
482
- with gr.Row():
483
- openai_key_input = gr.Textbox(
484
- label="OpenAI API Key",
485
- placeholder="Paste your OpenAI API key here",
486
- lines=1,
487
- type="password",
488
- )
489
  with gr.Row():
490
  dataset_xml_file = gr.File(
491
  label="Upload dataset_structure.xml (optional)",
@@ -509,17 +576,35 @@ with gr.Blocks(
509
  lines=6,
510
  )
511
 
512
- with gr.Accordion("LLM settings (advanced)", open=False):
513
- provider_input = gr.Dropdown(
514
- label="Provider",
515
- choices=["openai"],
516
- value="openai",
517
- )
518
- model_input = gr.Textbox(
519
- label="Model",
520
- value="gpt-4o-mini",
521
- placeholder="e.g., gpt-4o-mini, gpt-5",
522
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
  output_root_input = gr.Textbox(
525
  label="Output root",
@@ -554,10 +639,8 @@ with gr.Blocks(
554
 
555
  call_button = gr.Button("Call BIDSifier", variant="primary")
556
 
557
- llm_output_box = gr.Textbox(
558
- label="Raw BIDSifier output",
559
- lines=10,
560
- interactive=True,
561
  )
562
 
563
  commands_box = gr.Textbox(
@@ -578,10 +661,17 @@ with gr.Blocks(
578
 
579
  # Wiring
580
 
 
 
 
 
 
 
581
  call_button.click(
582
  fn=call_bidsifier_step,
583
  inputs=[
584
- openai_key_input,
 
585
  dataset_xml_input,
586
  readme_input,
587
  publication_input,
 
36
  NUM_STEPS = len(STEP_LABELS)
37
 
38
 
39
# Per-provider UI defaults: the default model name, a Markdown snippet
# explaining how to authenticate, and (optionally) a default API base URL.
# Consumed by select_provider() to populate the settings widgets.
# NOTE(review): the "key_base" key feeds the "API Base" textbox; the name
# suggests it may have been intended as "api_base" — confirm before renaming.
PROVIDERS_INFO = {
    "openai": {
        "model": "gpt-4o-mini",
        "doc": "You can authenticate by setting the **API Key** above.",
    },
    "anthropic": {
        "model": "claude-sonnet-4-5-20250929",
        "doc": "You can authenticate by setting the **API Key** above.",
    },
    "databricks": {
        "model": "databricks-llama-4-maverick",
        "doc": (
            "If you're on the Databricks platform, authentication is automatic"
            " via their SDK. If not, you can set the **API Key** above."
        ),
    },
    "gemini": {
        "model": "gemini-2.5-flash",
        "doc": "You can authenticate by setting the **API Key** above.",
    },
    "ollama_chat": {
        "model": "llama3.2:1b",
        "doc": """
First, install Ollama and launch its server with your LM:

```
> curl -fsSL https://ollama.ai/install.sh | sh
> ollama run llama3.2:1b
```

You can search all the Ollama models here: https://ollama.com/search

You can also manually input another provider. See the
[DSPy documentation](https://dspy.ai/api/models/LM/) for other providers and
local LLMs.
""",
        # Default local endpoint for the Ollama server.
        "key_base": "http://localhost:11434",
    },
}
78
+
79
+
80
  # Helpers
81
 
82
  def split_shell_commands(text: str) -> List[str]:
 
162
 
163
  # Core callbacks
164
 
165
def select_provider(provider: str) -> tuple[str, str, str]:
    """Return the UI defaults for the selected LLM provider.

    Looks the provider up in ``PROVIDERS_INFO`` (case-insensitively) and
    returns the values used to pre-fill the model, API-base, and
    documentation widgets.

    Parameters
    ----------
    provider : str
        The name of the provider.

    Returns
    -------
    tuple
        Three values: the default model name, the default API base URL
        (empty unless the provider defines one), and the provider
        documentation in Markdown. All three are empty strings when the
        provider is unknown.
    """
    provider = provider.lower()
    if provider not in PROVIDERS_INFO:
        # Unknown provider: clear all three widgets.
        return "", "", ""
    provider_info = PROVIDERS_INFO[provider]
    return (
        provider_info.get("model", ""),
        provider_info.get("key_base", ""),
        provider_info.get("doc", ""),
    )
187
+
188
+
189
  def call_bidsifier_step(
190
+ api_key: str,
191
+ api_base: str,
192
  dataset_xml: str,
193
  readme_text: str,
194
  publication_text: str,
 
203
 
204
  Parameters
205
  ----------
206
+ api_key : str
207
+ LLM provider API key.
208
+ api_base : str
209
+ Local LLM base URL.
210
  dataset_xml : str
211
  Dataset XML content.
212
  readme_text : str
 
255
  step_id = BIDSIFIER_STEPS[step_label]
256
  context = build_context(dataset_xml, readme_text, publication_text, output_root)
257
 
258
+ agent = BIDSifierAgent(
259
+ provider=provider, model=model, api_key=api_key, api_base=api_base
260
+ )
261
 
262
  # Decide whether to use the structured step prompt or a free-form query:
263
  if manual_prompt.strip():
 
276
  step_index = 0
277
 
278
  state = {
279
+ "api_key": api_key,
280
  "dataset_xml": dataset_xml,
281
  "readme_text": readme_text,
282
  "publication_text": publication_text,
 
289
  "commands": commands,
290
  }
291
 
292
+ llm_output_text = "## Raw BIDSifier output\n" + llm_output
293
+
294
+ return llm_output_text, commands_str, state, step_index
295
 
296
 
297
  def confirm_commands(
 
362
 
363
  agent = BIDSifierAgent(
364
  provider=last_state.get("provider", "openai"),
365
+ model=last_state.get("model", PROVIDERS_INFO["openai"]["model"]),
366
  )
367
 
368
  llm_output = agent.run_step(next_id, context)
 
553
  """
554
  )
555
 
 
 
 
 
 
 
 
556
  with gr.Row():
557
  dataset_xml_file = gr.File(
558
  label="Upload dataset_structure.xml (optional)",
 
576
  lines=6,
577
  )
578
 
579
+ with gr.Accordion("LLM settings"):
580
+ api_key_input = gr.Textbox(
581
+ label="API Key (OpenAI, Gemini, Anthropic, Databricks)",
582
+ placeholder=(
583
+ "Paste your LLM provider API key here."
584
+ ' See "Advanced settings" to set up non-OpenAI providers.'
585
+ ),
586
+ lines=1,
587
+ type="password",
 
588
  )
589
+ with gr.Accordion("Advanced settings", open=False):
590
+ default_provider = list(PROVIDERS_INFO.keys())[0]
591
+ provider_input = gr.Dropdown(
592
+ label="Provider",
593
+ choices=PROVIDERS_INFO.keys(),
594
+ value=default_provider,
595
+ interactive=True,
596
+ )
597
+ model_input = gr.Textbox(
598
+ label="Model",
599
+ value=PROVIDERS_INFO[default_provider]["model"],
600
+ placeholder="e.g., gpt-4o-mini, gpt-5",
601
+ )
602
+ api_base_input = gr.Textbox(
603
+ label="API Base (Local LLMs, other providers)",
604
+ placeholder="Write your API Base URL here",
605
+ lines=1,
606
+ )
607
+ model_doc = gr.Markdown(PROVIDERS_INFO[default_provider]["doc"])
608
 
609
  output_root_input = gr.Textbox(
610
  label="Output root",
 
639
 
640
  call_button = gr.Button("Call BIDSifier", variant="primary")
641
 
642
+ llm_output_box = gr.Markdown(
643
+ label="Raw BIDSifier output", show_label=True, container=True
 
 
644
  )
645
 
646
  commands_box = gr.Textbox(
 
661
 
662
  # Wiring
663
 
664
+ provider_input.change(
665
+ fn=select_provider,
666
+ inputs=provider_input,
667
+ outputs=[model_input, api_base_input, model_doc],
668
+ )
669
+
670
  call_button.click(
671
  fn=call_bidsifier_step,
672
  inputs=[
673
+ api_key_input,
674
+ api_base_input,
675
  dataset_xml_input,
676
  readme_input,
677
  publication_input,