from typing import Optional, Dict, Any
import os
from dotenv import load_dotenv
import dspy
import prompts
class BIDSifierAgent:
    """Wrapper around a dspy-configured chat LLM for step-wise BIDSification.

    The agent builds step-specific prompts via the sibling ``prompts`` module
    and sends them as a single system+user exchange to the configured model.
    """

    def __init__(
        self, *,
        provider: str,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        temperature: float | None = 0.2,
    ):
        """Configure the underlying dspy LM.

        Args:
            provider: LiteLLM provider prefix (e.g. ``"openai"``).
            model: Model name; ``gpt-5*`` models get special handling.
            api_key: Explicit API key; falls back to ``{PROVIDER}_API_KEY``
                from the environment (``.env`` is loaded first).
            api_base: Optional custom endpoint, forwarded to ``dspy.LM``.
            temperature: Sampling temperature; forced to 1.0 for ``gpt-5*``
                reasoning models, which reject other values.
        """
        load_dotenv()
        if model.startswith("gpt-5"):
            # Reasoning models only accept temperature=1.0 and need a much
            # larger completion budget for their hidden reasoning tokens.
            temperature = 1.0
            max_tokens = 40000
        else:
            # Bug fix: previously the caller-supplied temperature was
            # unconditionally overwritten with None here, making the
            # parameter (and its 0.2 default) dead. Keep the caller's value.
            max_tokens = 10000
        lm = dspy.LM(
            f"{provider}/{model}",
            api_key=api_key or os.getenv(f"{provider.upper()}_API_KEY"),
            # Bug fix: api_base was accepted but never forwarded.
            api_base=api_base,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        dspy.configure(lm=lm)
        self.llm = lm
        # NOTE(review): this fallback only affects self.model bookkeeping —
        # the LM above was already built with the raw `model` argument.
        self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
        self.temperature = temperature

    def _build_user_prompt(self, step: str, context: Dict[str, Any]) -> str:
        """Return the user prompt for *step*, filled from *context*.

        Raises:
            ValueError: If *step* is not a known pipeline step.
        """
        dataset_xml = context.get("dataset_xml")
        readme_text = context.get("readme_text")
        publication_text = context.get("publication_text")
        output_root = context.get("output_root", "./bids_output")

        # The summary step is the only one that does not need output_root.
        if step == "summary":
            return prompts.summarize_dataset_prompt(
                dataset_xml=dataset_xml,
                readme_text=readme_text,
                publication_text=publication_text,
            )

        builders = {
            "create_metadata": prompts.create_metadata_prompt,
            "create_structure": prompts.create_structure_prompt,
            "rename_move": prompts.rename_and_move_prompt,
        }
        try:
            builder = builders[step]
        except KeyError:
            raise ValueError(f"Unknown step: {step}") from None
        return builder(
            output_root=output_root,
            dataset_xml=dataset_xml,
            readme_text=readme_text,
            publication_text=publication_text,
        )

    def _chat(self, user_msg: str) -> str:
        """Send one system+user exchange and return the first completion."""
        resp = self.llm(
            messages=[
                {"role": "system", "content": prompts.system_prompt()},
                {"role": "user", "content": user_msg},
            ],
            temperature=self.temperature,
        )
        return resp[0]

    def run_step(self, step: str, context: Dict[str, Any]) -> str:
        """Run one named BIDSification step against the model."""
        return self._chat(self._build_user_prompt(step, context))

    def run_query(self, query: str) -> str:
        """Send a free-form query (with the standard system prompt)."""
        return self._chat(query)
__all__ = ["BIDSifierAgent"]