Spaces:
Sleeping
Sleeping
Add multi-provider LLM support (OpenAI, Anthropic, Gemini, Databricks, Ollama)
Browse files- agent.py +19 -10
- gradio-ui.py +117 -27
agent.py
CHANGED
|
@@ -9,25 +9,34 @@ import prompts
|
|
| 9 |
class BIDSifierAgent:
|
| 10 |
"""Wrapper around OpenAI chat API for step-wise BIDSification."""
|
| 11 |
|
| 12 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
load_dotenv()
|
| 14 |
|
| 15 |
-
if
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
lm = dspy.LM(f"{provider}/{model}", api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), temperature = temperature, max_tokens = 40000)
|
| 19 |
-
else:
|
| 20 |
-
lm = dspy.LM(f"{provider}/{model}", api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), temperature = temperature, max_tokens = 10000)
|
| 21 |
else:
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
dspy.configure(lm=lm)
|
| 27 |
self.llm = lm
|
| 28 |
self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
|
| 29 |
self.temperature = temperature
|
| 30 |
-
|
| 31 |
def _build_user_prompt(self, step: str, context: Dict[str, Any]) -> str:
|
| 32 |
dataset_xml = context.get("dataset_xml")
|
| 33 |
readme_text = context.get("readme_text")
|
|
|
|
| 9 |
class BIDSifierAgent:
|
| 10 |
"""Wrapper around OpenAI chat API for step-wise BIDSification."""
|
| 11 |
|
| 12 |
+
def __init__(
|
| 13 |
+
self, *,
|
| 14 |
+
provider: str,
|
| 15 |
+
model: str,
|
| 16 |
+
api_key: Optional[str] = None,
|
| 17 |
+
api_base: Optional[str] = None,
|
| 18 |
+
temperature: Optional[float | None] = 0.2,
|
| 19 |
+
):
|
| 20 |
load_dotenv()
|
| 21 |
|
| 22 |
+
if model.startswith("gpt-5"): # reasoning model that requires special handling
|
| 23 |
+
temperature = 1.0
|
| 24 |
+
max_tokens = 40000
|
|
|
|
|
|
|
|
|
|
| 25 |
else:
|
| 26 |
+
temperature = None
|
| 27 |
+
max_tokens = 10000
|
| 28 |
|
| 29 |
+
lm = dspy.LM(
|
| 30 |
+
f"{provider}/{model}",
|
| 31 |
+
api_key=api_key or os.getenv(f"{provider.upper()}_API_KEY"),
|
| 32 |
+
temperature=temperature, max_tokens=max_tokens
|
| 33 |
+
)
|
| 34 |
|
| 35 |
dspy.configure(lm=lm)
|
| 36 |
self.llm = lm
|
| 37 |
self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
|
| 38 |
self.temperature = temperature
|
| 39 |
+
|
| 40 |
def _build_user_prompt(self, step: str, context: Dict[str, Any]) -> str:
|
| 41 |
dataset_xml = context.get("dataset_xml")
|
| 42 |
readme_text = context.get("readme_text")
|
gradio-ui.py
CHANGED
|
@@ -36,6 +36,47 @@ STEP_LABELS = list(BIDSIFIER_STEPS.keys())
|
|
| 36 |
NUM_STEPS = len(STEP_LABELS)
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
# Helpers
|
| 40 |
|
| 41 |
def split_shell_commands(text: str) -> List[str]:
|
|
@@ -121,8 +162,33 @@ def build_context(
|
|
| 121 |
|
| 122 |
# Core callbacks
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
def call_bidsifier_step(
|
| 125 |
-
|
|
|
|
| 126 |
dataset_xml: str,
|
| 127 |
readme_text: str,
|
| 128 |
publication_text: str,
|
|
@@ -137,6 +203,10 @@ def call_bidsifier_step(
|
|
| 137 |
|
| 138 |
Parameters
|
| 139 |
----------
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
dataset_xml : str
|
| 141 |
Dataset XML content.
|
| 142 |
readme_text : str
|
|
@@ -185,7 +255,9 @@ def call_bidsifier_step(
|
|
| 185 |
step_id = BIDSIFIER_STEPS[step_label]
|
| 186 |
context = build_context(dataset_xml, readme_text, publication_text, output_root)
|
| 187 |
|
| 188 |
-
agent = BIDSifierAgent(
|
|
|
|
|
|
|
| 189 |
|
| 190 |
# Decide whether to use the structured step prompt or a free-form query:
|
| 191 |
if manual_prompt.strip():
|
|
@@ -204,7 +276,7 @@ def call_bidsifier_step(
|
|
| 204 |
step_index = 0
|
| 205 |
|
| 206 |
state = {
|
| 207 |
-
"
|
| 208 |
"dataset_xml": dataset_xml,
|
| 209 |
"readme_text": readme_text,
|
| 210 |
"publication_text": publication_text,
|
|
@@ -217,7 +289,9 @@ def call_bidsifier_step(
|
|
| 217 |
"commands": commands,
|
| 218 |
}
|
| 219 |
|
| 220 |
-
|
|
|
|
|
|
|
| 221 |
|
| 222 |
|
| 223 |
def confirm_commands(
|
|
@@ -288,7 +362,7 @@ def confirm_commands(
|
|
| 288 |
|
| 289 |
agent = BIDSifierAgent(
|
| 290 |
provider=last_state.get("provider", "openai"),
|
| 291 |
-
model=last_state.get("model", "
|
| 292 |
)
|
| 293 |
|
| 294 |
llm_output = agent.run_step(next_id, context)
|
|
@@ -479,13 +553,6 @@ with gr.Blocks(
|
|
| 479 |
"""
|
| 480 |
)
|
| 481 |
|
| 482 |
-
with gr.Row():
|
| 483 |
-
openai_key_input = gr.Textbox(
|
| 484 |
-
label="OpenAI API Key",
|
| 485 |
-
placeholder="Paste your OpenAI API key here",
|
| 486 |
-
lines=1,
|
| 487 |
-
type="password",
|
| 488 |
-
)
|
| 489 |
with gr.Row():
|
| 490 |
dataset_xml_file = gr.File(
|
| 491 |
label="Upload dataset_structure.xml (optional)",
|
|
@@ -509,17 +576,35 @@ with gr.Blocks(
|
|
| 509 |
lines=6,
|
| 510 |
)
|
| 511 |
|
| 512 |
-
with gr.Accordion("LLM settings
|
| 513 |
-
|
| 514 |
-
label="
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
placeholder="e.g., gpt-4o-mini, gpt-5",
|
| 522 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
|
| 524 |
output_root_input = gr.Textbox(
|
| 525 |
label="Output root",
|
|
@@ -554,10 +639,8 @@ with gr.Blocks(
|
|
| 554 |
|
| 555 |
call_button = gr.Button("Call BIDSifier", variant="primary")
|
| 556 |
|
| 557 |
-
llm_output_box = gr.
|
| 558 |
-
label="Raw BIDSifier output",
|
| 559 |
-
lines=10,
|
| 560 |
-
interactive=True,
|
| 561 |
)
|
| 562 |
|
| 563 |
commands_box = gr.Textbox(
|
|
@@ -578,10 +661,17 @@ with gr.Blocks(
|
|
| 578 |
|
| 579 |
# Wiring
|
| 580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
call_button.click(
|
| 582 |
fn=call_bidsifier_step,
|
| 583 |
inputs=[
|
| 584 |
-
|
|
|
|
| 585 |
dataset_xml_input,
|
| 586 |
readme_input,
|
| 587 |
publication_input,
|
|
|
|
| 36 |
NUM_STEPS = len(STEP_LABELS)
|
| 37 |
|
| 38 |
|
| 39 |
+
PROVIDERS_INFO = {
|
| 40 |
+
"openai": {
|
| 41 |
+
"model": "gpt-4o-mini",
|
| 42 |
+
"doc": "You can authenticate by setting the **API Key** above.",
|
| 43 |
+
},
|
| 44 |
+
"anthropic": {
|
| 45 |
+
"model": "claude-sonnet-4-5-20250929",
|
| 46 |
+
"doc": "You can authenticate by setting the **API Key** above.",
|
| 47 |
+
},
|
| 48 |
+
"databricks": {
|
| 49 |
+
"model": "databricks-llama-4-maverick",
|
| 50 |
+
"doc": (
|
| 51 |
+
"If you're on the Databricks platform, authentication is automatic"
|
| 52 |
+
" via their SDK. If not, you can set the **API Key** above."
|
| 53 |
+
),
|
| 54 |
+
},
|
| 55 |
+
"gemini": {
|
| 56 |
+
"model": "gemini-2.5-flash",
|
| 57 |
+
"doc": "You can authenticate by setting the **API Key** above.",
|
| 58 |
+
},
|
| 59 |
+
"ollama_chat": {
|
| 60 |
+
"model": "llama3.2:1b",
|
| 61 |
+
"doc": """
|
| 62 |
+
First, install Ollama and launch its server with your LM:
|
| 63 |
+
|
| 64 |
+
```
|
| 65 |
+
> curl -fsSL https://ollama.ai/install.sh | sh
|
| 66 |
+
> ollama run llama3.2:1b
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
You can search all the Ollama models here: https://ollama.com/search
|
| 70 |
+
|
| 71 |
+
You can also manually input another provider. See the
|
| 72 |
+
[DSPy documentation](https://dspy.ai/api/models/LM/) for other providers and
|
| 73 |
+
local LLMs.
|
| 74 |
+
""",
|
| 75 |
+
"key_base": "http://localhost:11434",
|
| 76 |
+
},
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
# Helpers
|
| 81 |
|
| 82 |
def split_shell_commands(text: str) -> List[str]:
|
|
|
|
| 162 |
|
| 163 |
# Core callbacks
|
| 164 |
|
| 165 |
+
def select_provider(provider: str):
|
| 166 |
+
"""Fills the default provider model base, and description.
|
| 167 |
+
|
| 168 |
+
Parameters
|
| 169 |
+
----------
|
| 170 |
+
provider : str
|
| 171 |
+
The name of the provider.
|
| 172 |
+
|
| 173 |
+
Returns
|
| 174 |
+
-------
|
| 175 |
+
list
|
| 176 |
+
A list of two values: model name, and provider documentation in MD.
|
| 177 |
+
"""
|
| 178 |
+
provider = provider.lower()
|
| 179 |
+
if provider not in PROVIDERS_INFO:
|
| 180 |
+
return "", "", ""
|
| 181 |
+
provider_info = PROVIDERS_INFO[provider]
|
| 182 |
+
return (
|
| 183 |
+
provider_info.get("model", ""),
|
| 184 |
+
provider_info.get("key_base", ""),
|
| 185 |
+
provider_info.get("doc", ""),
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
def call_bidsifier_step(
|
| 190 |
+
api_key: str,
|
| 191 |
+
api_base: str,
|
| 192 |
dataset_xml: str,
|
| 193 |
readme_text: str,
|
| 194 |
publication_text: str,
|
|
|
|
| 203 |
|
| 204 |
Parameters
|
| 205 |
----------
|
| 206 |
+
api_key : str
|
| 207 |
+
LLM provider API key.
|
| 208 |
+
api_base : str
|
| 209 |
+
Local LLM base URL.
|
| 210 |
dataset_xml : str
|
| 211 |
Dataset XML content.
|
| 212 |
readme_text : str
|
|
|
|
| 255 |
step_id = BIDSIFIER_STEPS[step_label]
|
| 256 |
context = build_context(dataset_xml, readme_text, publication_text, output_root)
|
| 257 |
|
| 258 |
+
agent = BIDSifierAgent(
|
| 259 |
+
provider=provider, model=model, api_key=api_key, api_base=api_base
|
| 260 |
+
)
|
| 261 |
|
| 262 |
# Decide whether to use the structured step prompt or a free-form query:
|
| 263 |
if manual_prompt.strip():
|
|
|
|
| 276 |
step_index = 0
|
| 277 |
|
| 278 |
state = {
|
| 279 |
+
"api_key": api_key,
|
| 280 |
"dataset_xml": dataset_xml,
|
| 281 |
"readme_text": readme_text,
|
| 282 |
"publication_text": publication_text,
|
|
|
|
| 289 |
"commands": commands,
|
| 290 |
}
|
| 291 |
|
| 292 |
+
llm_output_text = "## Raw BIDSifier output\n" + llm_output
|
| 293 |
+
|
| 294 |
+
return llm_output_text, commands_str, state, step_index
|
| 295 |
|
| 296 |
|
| 297 |
def confirm_commands(
|
|
|
|
| 362 |
|
| 363 |
agent = BIDSifierAgent(
|
| 364 |
provider=last_state.get("provider", "openai"),
|
| 365 |
+
model=last_state.get("model", PROVIDERS_INFO["openai"]["model"]),
|
| 366 |
)
|
| 367 |
|
| 368 |
llm_output = agent.run_step(next_id, context)
|
|
|
|
| 553 |
"""
|
| 554 |
)
|
| 555 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
with gr.Row():
|
| 557 |
dataset_xml_file = gr.File(
|
| 558 |
label="Upload dataset_structure.xml (optional)",
|
|
|
|
| 576 |
lines=6,
|
| 577 |
)
|
| 578 |
|
| 579 |
+
with gr.Accordion("LLM settings"):
|
| 580 |
+
api_key_input = gr.Textbox(
|
| 581 |
+
label="API Key (OpenAI, Gemini, Anthropic, Databricks)",
|
| 582 |
+
placeholder=(
|
| 583 |
+
"Paste your LLM provider API key here."
|
| 584 |
+
' See "Advanced settings" to set up non-OpenAI providers.'
|
| 585 |
+
),
|
| 586 |
+
lines=1,
|
| 587 |
+
type="password",
|
|
|
|
| 588 |
)
|
| 589 |
+
with gr.Accordion("Advanced settings", open=False):
|
| 590 |
+
default_provider = list(PROVIDERS_INFO.keys())[0]
|
| 591 |
+
provider_input = gr.Dropdown(
|
| 592 |
+
label="Provider",
|
| 593 |
+
choices=PROVIDERS_INFO.keys(),
|
| 594 |
+
value=default_provider,
|
| 595 |
+
interactive=True,
|
| 596 |
+
)
|
| 597 |
+
model_input = gr.Textbox(
|
| 598 |
+
label="Model",
|
| 599 |
+
value=PROVIDERS_INFO[default_provider]["model"],
|
| 600 |
+
placeholder="e.g., gpt-4o-mini, gpt-5",
|
| 601 |
+
)
|
| 602 |
+
api_base_input = gr.Textbox(
|
| 603 |
+
label="API Base (Local LLMs, other providers)",
|
| 604 |
+
placeholder="Write your API Base URL here",
|
| 605 |
+
lines=1,
|
| 606 |
+
)
|
| 607 |
+
model_doc = gr.Markdown(PROVIDERS_INFO[default_provider]["doc"])
|
| 608 |
|
| 609 |
output_root_input = gr.Textbox(
|
| 610 |
label="Output root",
|
|
|
|
| 639 |
|
| 640 |
call_button = gr.Button("Call BIDSifier", variant="primary")
|
| 641 |
|
| 642 |
+
llm_output_box = gr.Markdown(
|
| 643 |
+
label="Raw BIDSifier output", show_label=True, container=True
|
|
|
|
|
|
|
| 644 |
)
|
| 645 |
|
| 646 |
commands_box = gr.Textbox(
|
|
|
|
| 661 |
|
| 662 |
# Wiring
|
| 663 |
|
| 664 |
+
provider_input.change(
|
| 665 |
+
fn=select_provider,
|
| 666 |
+
inputs=provider_input,
|
| 667 |
+
outputs=[model_input, api_base_input, model_doc],
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
call_button.click(
|
| 671 |
fn=call_bidsifier_step,
|
| 672 |
inputs=[
|
| 673 |
+
api_key_input,
|
| 674 |
+
api_base_input,
|
| 675 |
dataset_xml_input,
|
| 676 |
readme_input,
|
| 677 |
publication_input,
|