File size: 2,631 Bytes
4379b25
7428353
 
8d0fd38
7428353
9887d83
4379b25
 
 
 
 
0f31aa7
 
 
 
 
 
 
 
4379b25
8d0fd38
0f31aa7
 
 
8d0fd38
0f31aa7
 
9d1b729
0f31aa7
 
 
 
 
9d1b729
8d0fd38
 
4379b25
 
0f31aa7
4379b25
 
 
 
 
 
 
9887d83
4379b25
 
 
 
 
9887d83
4379b25
 
 
 
 
 
9887d83
4379b25
 
 
 
 
 
9887d83
4379b25
 
 
 
 
 
 
 
9887d83
4379b25
8d0fd38
4379b25
 
 
 
8d0fd38
4379b25
8d0fd38
4379b25
9887d83
 
 
 
 
 
 
 
 
 
 
4379b25
 
7428353
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from typing import Optional, Dict, Any
import os
from dotenv import load_dotenv
import dspy

import prompts


class BIDSifierAgent:
	"""Wrapper around a dspy LM for step-wise BIDSification.

	NOTE: the constructor calls ``dspy.configure(lm=...)``, which mutates
	dspy's process-global default LM as a side effect.
	"""

	def __init__(
		self, *,
		provider: str,
		model: str,
		api_key: Optional[str] = None,
		api_base: Optional[str] = None,
		temperature: Optional[float] = 0.2,
	):
		"""Create the agent and configure dspy's global LM.

		Args:
			provider: LM provider prefix (e.g. ``"openai"``); also used to
				resolve the ``{PROVIDER}_API_KEY`` environment variable when
				*api_key* is not given.
			model: Model identifier, e.g. ``"gpt-4o-mini"``.
			api_key: Explicit API key; falls back to the env var above.
			api_base: Accepted for interface compatibility but currently not
				forwarded to ``dspy.LM`` — NOTE(review): confirm whether this
				was meant to be passed through.
			temperature: Sampling temperature. Forced to ``1.0`` for
				``gpt-5*`` reasoning models, which require it.
		"""
		load_dotenv()

		if model.startswith("gpt-5"):  # reasoning models require temperature=1.0
			temperature = 1.0
			max_tokens = 40000
		else:
			# FIX: previously the caller-supplied temperature was discarded
			# here (reset to None), contradicting the documented 0.2 default.
			max_tokens = 10000

		lm = dspy.LM(
			f"{provider}/{model}",
			api_key=api_key or os.getenv(f"{provider.upper()}_API_KEY"),
			temperature=temperature,
			max_tokens=max_tokens,
		)

		dspy.configure(lm=lm)  # global side effect: sets dspy's default LM
		self.llm = lm
		# ``model`` is a required argument, so the env fallback only kicks
		# in if an empty string is passed explicitly.
		self.model = model or os.getenv("BIDSIFIER_MODEL", "gpt-4o-mini")
		self.temperature = temperature

	def _build_user_prompt(self, step: str, context: Dict[str, Any]) -> str:
		"""Return the user prompt for *step*, built from *context*.

		Recognized context keys: ``dataset_xml``, ``readme_text``,
		``publication_text``, ``output_root`` (defaults to "./bids_output").

		Raises:
			ValueError: if *step* is not a known step name.
		"""
		dataset_xml = context.get("dataset_xml")
		readme_text = context.get("readme_text")
		publication_text = context.get("publication_text")
		output_root = context.get("output_root", "./bids_output")

		if step == "summary":
			# The summary step is the only one that takes no output_root.
			return prompts.summarize_dataset_prompt(
				dataset_xml=dataset_xml,
				readme_text=readme_text,
				publication_text=publication_text,
			)
		if step == "create_metadata":
			return prompts.create_metadata_prompt(
				output_root=output_root,
				dataset_xml=dataset_xml,
				readme_text=readme_text,
				publication_text=publication_text,
			)
		if step == "create_structure":
			return prompts.create_structure_prompt(
				output_root=output_root,
				dataset_xml=dataset_xml,
				readme_text=readme_text,
				publication_text=publication_text,
			)
		if step == "rename_move":
			return prompts.rename_and_move_prompt(
				output_root=output_root,
				dataset_xml=dataset_xml,
				readme_text=readme_text,
				publication_text=publication_text,
			)
		raise ValueError(f"Unknown step: {step}")

	def _chat(self, user_msg: str) -> str:
		"""Send one system+user exchange to the LM; return the first completion."""
		resp = self.llm(
			messages=[
				{"role": "system", "content": prompts.system_prompt()},
				{"role": "user", "content": user_msg},
			],
			temperature=self.temperature,
		)
		return resp[0]

	def run_step(self, step: str, context: Dict[str, Any]) -> str:
		"""Build the prompt for *step* from *context* and run it against the LM."""
		return self._chat(self._build_user_prompt(step, context))

	def run_query(self, query: str) -> str:
		"""Run a free-form *query* against the LM and return the completion."""
		return self._chat(query)


__all__ = ["BIDSifierAgent"]  # explicit public API of this module