|
|
import copy |
|
|
import asyncio |
|
|
from typing import Callable, Dict, Union, Awaitable |
|
|
from pydantic import Field |
|
|
import dspy |
|
|
from ...optimizers.engine.registry import ParamRegistry |
|
|
from typing import List |
|
|
|
|
|
from ...core.logging import logger |
|
|
from ...prompts.template import PromptTemplate |
|
|
from dspy.utils.saving import get_dependency_versions |
|
|
from pathlib import Path |
|
|
import cloudpickle |
|
|
import ujson |
|
|
|
|
|
|
|
|
class PromptTuningModule(dspy.Module): |
|
|
""" |
|
|
A prompt tuning module that manages interactions between predictors, |
|
|
parameter registry, and program functions. |
|
|
|
|
|
This module coordinates prompt optimization through: |
|
|
1. Maintaining a set of predictors for different tasks |
|
|
2. Synchronizing optimized parameters back to the program |
|
|
3. Executing the program with updated parameters |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
program : Union[Callable[..., dict], Callable[..., Awaitable[dict]]] |
|
|
The main program function to execute. Can be either synchronous or asynchronous. |
|
|
Must return a dictionary containing execution results. |
|
|
signature_dict : Dict[str, dspy.Signature] |
|
|
A mapping of task names to their corresponding DSPy signatures. |
|
|
Each signature defines the input/output structure for a specific task. |
|
|
registry : ParamRegistry |
|
|
A registry that maintains tunable parameters shared between |
|
|
predictors and the program. |
|
|
""" |
|
|
|
|
|
@classmethod |
|
|
def from_registry( |
|
|
cls, |
|
|
program: Union[Callable[..., dict], Callable[..., Awaitable[dict]]], |
|
|
registry: ParamRegistry, |
|
|
) -> "PromptTuningModule": |
|
|
""" |
|
|
Factory method to create a PromptTuningModule from a registry and program. |
|
|
|
|
|
This method: |
|
|
1. Creates signatures for each field in the registry |
|
|
2. Initializes a PromptTuningModule with the program and signatures |
|
|
3. Sets up predictors for each signature |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
program : Union[Callable[..., dict], Callable[..., Awaitable[dict]]] |
|
|
The main program function to execute |
|
|
registry : ParamRegistry |
|
|
Registry containing tunable parameters |
|
|
|
|
|
Returns |
|
|
------- |
|
|
PromptTuningModule |
|
|
A configured PromptTuningModule instance |
|
|
|
|
|
Examples |
|
|
-------- |
|
|
>>> registry = ParamRegistry() |
|
|
>>> registry.register("task1", "What is {topic}?") |
|
|
>>> registry.register("task2", PromptTemplate(system="You are helpful.", user="{query}")) |
|
|
>>> def my_program(**kwargs) -> dict: |
|
|
... return {"result": "done"} |
|
|
>>> module = PromptTuningModule.from_registry(my_program, registry) |
|
|
""" |
|
|
from .signature_utils import signature_from_registry |
|
|
|
|
|
|
|
|
signature_dict, signature_name2register_name = signature_from_registry( |
|
|
registry=registry, |
|
|
) |
|
|
|
|
|
|
|
|
return cls(program=program, signature_dict=signature_dict, registry=registry, signature_name2register_name=signature_name2register_name) |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
program: Union[Callable[..., dict], Callable[..., Awaitable[dict]]], |
|
|
signature_dict: Dict[str, dspy.Signature], |
|
|
registry: ParamRegistry, |
|
|
signature_name2register_name: Dict[str, str], |
|
|
): |
|
|
""" |
|
|
Initialize a PromptTuningModule instance. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
program : Union[Callable[..., dict], Callable[..., Awaitable[dict]]] |
|
|
The main program function to execute |
|
|
signature_dict : Dict[str, dspy.Signature] |
|
|
Mapping of task names to signatures |
|
|
registry : ParamRegistry |
|
|
Parameter registry |
|
|
signature_name2register_name : Dict[str, str] |
|
|
Mapping of signature names to register names |
|
|
""" |
|
|
super().__init__() |
|
|
self.program = program |
|
|
self.predicts = [] |
|
|
|
|
|
seen = set() |
|
|
for name, signature in signature_dict.items(): |
|
|
if name in seen: |
|
|
raise ValueError(f"Duplicate name {name} in signature_dict") |
|
|
seen.add(name) |
|
|
self.predicts.append(dspy.Predict(signature, name=name)) |
|
|
self.registry = registry |
|
|
self.signature_name2register_name = signature_name2register_name |
|
|
|
|
|
def reset(self): |
|
|
""" |
|
|
Reset the module to its initial state. |
|
|
""" |
|
|
self.registry.reset() |
|
|
|
|
|
for predict in self.predicts: |
|
|
signature = predict.signature |
|
|
|
|
|
signature_name = signature.__name__ |
|
|
register_name = self.signature_name2register_name[signature_name] |
|
|
|
|
|
register_element = self.registry.get(register_name) |
|
|
|
|
|
if isinstance(register_element, PromptTemplate): |
|
|
predict.signature.instructions = register_element.instruction |
|
|
predict.demos = register_element.demonstrations |
|
|
elif isinstance(register_element, str): |
|
|
predict.signature.instructions = register_element |
|
|
predict.demos = [] |
|
|
else: |
|
|
logger.warning(f"Unsupported register element type: {type(register_element)}") |
|
|
|
|
|
|
|
|
return self |
|
|
|
|
|
def escape_braces(self, text): |
|
|
""" |
|
|
Escape all braces in the text. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
text : str |
|
|
Text that needs escaping |
|
|
|
|
|
Returns |
|
|
------- |
|
|
str |
|
|
Escaped text |
|
|
""" |
|
|
def helper(s, start=0): |
|
|
result = '' |
|
|
i = start |
|
|
while i < len(s): |
|
|
if s[i] == '{': |
|
|
inner, new_i = helper(s, i + 1) |
|
|
result += '{{' + inner + '}}' |
|
|
i = new_i |
|
|
elif s[i] == '}': |
|
|
return result, i + 1 |
|
|
else: |
|
|
result += s[i] |
|
|
i += 1 |
|
|
return result, i |
|
|
|
|
|
escaped, _ = helper(text) |
|
|
return escaped |
|
|
|
|
|
def _validate_prompt(self, prompt: str, input_names: List[str], verbose: bool = True) -> str: |
|
|
""" |
|
|
Validate if the generated prompt is valid. Currently only checks if required inputs are wrapped in braces. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
prompt : str |
|
|
The prompt to validate |
|
|
input_names : List[str] |
|
|
List of required input names |
|
|
verbose : bool, optional |
|
|
Whether to show detailed information, defaults to True |
|
|
|
|
|
Returns |
|
|
------- |
|
|
str |
|
|
Validated and potentially modified prompt |
|
|
""" |
|
|
modified_messages = [] |
|
|
required_inputs = input_names |
|
|
missing_required_inputs = [name for name in required_inputs if f"{{{name}}}" not in prompt] |
|
|
if missing_required_inputs: |
|
|
input_values = "\n\n".join([f"{name}: {{{name}}}" for name in missing_required_inputs]) |
|
|
prompt += f"\n\nThe followings are some required input values: \n{input_values}" |
|
|
modified_messages.append(f"added missing inputs: {', '.join(missing_required_inputs)}") |
|
|
|
|
|
prompt = self.escape_braces(prompt) |
|
|
for name in input_names: |
|
|
prompt = prompt.replace(f"{{{{{name}}}}}", f"{{{name}}}") |
|
|
prompt = prompt.replace(r"{{{{", r"{{").replace(r"}}}}", r"}}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return prompt |
|
|
|
|
|
def get_field_type(self, field: Field) -> str: |
|
|
""" |
|
|
Get the type of the field. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
field : Field |
|
|
The field to get type from |
|
|
|
|
|
Returns |
|
|
------- |
|
|
str |
|
|
The field type |
|
|
""" |
|
|
return field.json_schema_extra.get('__dspy_field_type') if field.json_schema_extra.get('__dspy_field_type') else None |
|
|
|
|
|
def is_prompt_template(self, register_name: str) -> bool: |
|
|
""" |
|
|
Check if the register name is a prompt template. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
register_name : str |
|
|
The register name to check |
|
|
|
|
|
Returns |
|
|
------- |
|
|
bool |
|
|
Whether it is a prompt template |
|
|
""" |
|
|
return self.registry.get(register_name) is not None and isinstance(self.registry.get(register_name), PromptTemplate) |
|
|
|
|
|
def get_demos(self, demos: list) -> List[dict]: |
|
|
result = [] |
|
|
for demo in demos: |
|
|
if isinstance(demo, dspy.Example): |
|
|
demo = demo.toDict() |
|
|
result.append(demo) |
|
|
return result |
|
|
|
|
|
def _inject_demos_to_string(self, instruction: str, demos: List[dict], input_names: List[str], output_names: List[str]) -> str: |
|
|
""" |
|
|
Inject demos to the instruction. |
|
|
""" |
|
|
if not demos: |
|
|
return instruction |
|
|
|
|
|
def _escape_braces(text: str) -> str: |
|
|
return text.replace("{", "{{").replace("}", "}}") |
|
|
|
|
|
def format_demo(demo: dict) -> str: |
|
|
demo_str = "Inputs:\n" |
|
|
inputs = {name: demo.get(name, "Not provided") for name in input_names} |
|
|
demo_str += "\n".join([f"{name}:\n{_escape_braces(str(value))}" for name, value in inputs.items()]) |
|
|
demo_str += "\n\nOutputs:\n" |
|
|
outputs = {name: demo.get(name, "Not provided") for name in output_names} |
|
|
demo_str += "\n".join([f"{name}:\n{_escape_braces(str(value))}" for name, value in outputs.items()]) |
|
|
return demo_str |
|
|
|
|
|
demos_string = "\n\n".join([f"Example {i+1}:\n{format_demo(demo)}" for i, demo in enumerate(demos)]) |
|
|
prompt = f"{instruction}\n\nThe following are some examples:\n{demos_string}" |
|
|
return prompt |
|
|
|
|
|
def sync_predict_inputs_to_program(self): |
|
|
""" |
|
|
Synchronize current input values from all predictors back to the registry. |
|
|
|
|
|
This method ensures that any optimized parameters in the predictors' configurations |
|
|
are properly reflected in the registry, which in turn affects program execution. |
|
|
|
|
|
Synchronization process: |
|
|
1. Iterate through all predictors |
|
|
2. For each predictor, check its signature's input fields |
|
|
3. If a field has a value in the predictor's config, update the registry |
|
|
|
|
|
Note: Values in predictor configs take precedence as they may contain |
|
|
optimized values from recent tuning iterations. |
|
|
""" |
|
|
for predict in self.predicts: |
|
|
signature = predict.signature |
|
|
instruction = signature.instructions |
|
|
demos = predict.demos |
|
|
|
|
|
input_names = [name for name, field in predict.signature.fields.items() if self.get_field_type(field) == 'input'] |
|
|
output_names = [name for name, field in predict.signature.fields.items() if self.get_field_type(field) == 'output'] |
|
|
|
|
|
signature_name = signature.__name__ |
|
|
register_name = self.signature_name2register_name[signature_name] |
|
|
|
|
|
if self.is_prompt_template(register_name): |
|
|
prompt_template: PromptTemplate = self.registry.get(register_name) |
|
|
prompt_template.instruction = instruction |
|
|
prompt_template.demonstrations = self.get_demos(demos) |
|
|
self.registry.set(register_name, prompt_template) |
|
|
else: |
|
|
instruction = self._validate_prompt(instruction, input_names) |
|
|
|
|
|
prompt = self._inject_demos_to_string(instruction, self.get_demos(demos), input_names, output_names) |
|
|
self.registry.set(register_name, prompt) |
|
|
|
|
|
def constrcut_trace(self, execution_data: dict) -> dict: |
|
|
""" |
|
|
Construct the trace of the execution. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
execution_data : dict |
|
|
Execution data |
|
|
|
|
|
Returns |
|
|
------- |
|
|
dict |
|
|
Trace information |
|
|
""" |
|
|
trace: List[dict] = [] |
|
|
for predict in self.predicts: |
|
|
input_names = [name for name, field in predict.signature.fields.items() if self.get_field_type(field) == 'input'] |
|
|
output_names = [name for name, field in predict.signature.fields.items() if self.get_field_type(field) == 'output'] |
|
|
|
|
|
input_dict = {} |
|
|
output_dict = {} |
|
|
|
|
|
|
|
|
for name in input_names: |
|
|
if name not in execution_data: |
|
|
logger.warning(f"Input {name} not found in execution data") |
|
|
for name in output_names: |
|
|
if name not in execution_data: |
|
|
logger.warning(f"Output {name} not found in execution data") |
|
|
|
|
|
|
|
|
for name in input_names: |
|
|
if name in execution_data: |
|
|
input_dict[name] = execution_data[name] |
|
|
for name in output_names: |
|
|
if name in execution_data: |
|
|
output_dict[name] = execution_data[name] |
|
|
|
|
|
trace_tuple = (predict, input_dict, output_dict) |
|
|
trace.append(trace_tuple) |
|
|
return trace |
|
|
|
|
|
def forward(self, **kwargs) -> dict: |
|
|
""" |
|
|
Execute the program with synchronized parameters and optional inputs. |
|
|
|
|
|
This method: |
|
|
1. Synchronizes optimized prompts back to the program via registry |
|
|
2. Executes the program (handles both sync and async functions) |
|
|
3. Validates and returns the program's output |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
**kwargs : dict |
|
|
Optional keyword arguments to pass to the program function |
|
|
|
|
|
Returns |
|
|
------- |
|
|
dict |
|
|
The program's execution results |
|
|
|
|
|
Raises |
|
|
------ |
|
|
ValueError |
|
|
If the program doesn't return a dictionary |
|
|
""" |
|
|
|
|
|
self.sync_predict_inputs_to_program() |
|
|
|
|
|
|
|
|
if asyncio.iscoroutinefunction(self.program): |
|
|
output, execution_data = asyncio.run(self.program(**kwargs)) if kwargs else asyncio.run(self.program()) |
|
|
else: |
|
|
output, execution_data = self.program(**kwargs) if kwargs else self.program() |
|
|
|
|
|
trace = self.constrcut_trace(execution_data) |
|
|
|
|
|
|
|
|
if dspy.settings.trace is not None: |
|
|
dspy_trace = dspy.settings.trace |
|
|
dspy_trace.extend(trace) |
|
|
|
|
|
return output |
|
|
|
|
|
def deepcopy(self): |
|
|
""" |
|
|
Deep copy the module. |
|
|
|
|
|
This is a tweak to the default Python deepcopy that only deep copies `self.parameters()`, |
|
|
and for other attributes, we just do a shallow copy. |
|
|
|
|
|
Returns |
|
|
------- |
|
|
PromptTuningModule |
|
|
A deep copy of the module |
|
|
""" |
|
|
try: |
|
|
|
|
|
new_instance = copy.deepcopy(self) |
|
|
setattr(new_instance, "program", self.program) |
|
|
return new_instance |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
new_instance = self.__class__.__new__(self.__class__) |
|
|
|
|
|
for attr, value in self.__dict__.items(): |
|
|
if isinstance(value, dspy.Module): |
|
|
setattr(new_instance, attr, value.deepcopy()) |
|
|
else: |
|
|
try: |
|
|
|
|
|
setattr(new_instance, attr, copy.deepcopy(value)) |
|
|
except Exception: |
|
|
try: |
|
|
|
|
|
setattr(new_instance, attr, copy.copy(value)) |
|
|
except Exception: |
|
|
|
|
|
setattr(new_instance, attr, value) |
|
|
|
|
|
|
|
|
setattr(new_instance, "program", self.program) |
|
|
return new_instance |
|
|
|
|
|
def save(self, path, save_program=False): |
|
|
"""Save the module. |
|
|
|
|
|
Save the module to a directory or a file. There are two modes: |
|
|
- `save_program=False`: Save only the state of the module to a json or pickle file, based on the value of |
|
|
the file extension. |
|
|
- `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and |
|
|
architecture of the model. |
|
|
|
|
|
We also save the dependency versions, so that the loaded model can check if there is a version mismatch on |
|
|
critical dependencies or DSPy version. |
|
|
|
|
|
Args: |
|
|
path (str): Path to the saved state file, which should be a .json or .pkl file when `save_program=False`, |
|
|
and a directory when `save_program=True`. |
|
|
save_program (bool): If True, save the whole module to a directory via cloudpickle, otherwise only save |
|
|
the state. |
|
|
""" |
|
|
|
|
|
metadata = {} |
|
|
metadata["dependency_versions"] = get_dependency_versions() |
|
|
path = Path(path) |
|
|
|
|
|
if not path.is_dir(): |
|
|
|
|
|
if not path.parent.exists(): |
|
|
path.parent.mkdir(parents=True) |
|
|
else: |
|
|
|
|
|
if not path.exists(): |
|
|
|
|
|
if not path.exists(): |
|
|
|
|
|
path.mkdir(parents=True) |
|
|
|
|
|
if hasattr(self.program, "save"): |
|
|
self.program.save(str(path)) |
|
|
return |
|
|
|
|
|
if save_program: |
|
|
if path.suffix: |
|
|
raise ValueError( |
|
|
f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}" |
|
|
) |
|
|
if path.exists() and not path.is_dir(): |
|
|
raise NotADirectoryError(f"The path '{path}' exists but is not a directory.") |
|
|
|
|
|
try: |
|
|
with open(path / "program.pkl", "wb") as f: |
|
|
cloudpickle.dump(self, f) |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, " |
|
|
"or consider using state-only saving by setting `save_program=False`." |
|
|
) |
|
|
with open(path / "metadata.json", "w") as f: |
|
|
ujson.dump(metadata, f, indent=4) |
|
|
|
|
|
return |
|
|
|
|
|
state = self.dump_state() |
|
|
state["metadata"] = metadata |
|
|
if path.suffix == ".json": |
|
|
try: |
|
|
with open(path, "w") as f: |
|
|
f.write(ujson.dumps(state, indent=4)) |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non " |
|
|
"json-serializable objects, please consider saving the state in .pkl by using `path` ending " |
|
|
"with `.pkl`, or saving the whole program by setting `save_program=True`." |
|
|
) |
|
|
elif path.suffix == ".pkl": |
|
|
with open(path, "wb") as f: |
|
|
cloudpickle.dump(state, f) |
|
|
else: |
|
|
raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}") |
|
|
|
|
|
def load(self, path): |
|
|
"""Load the saved module. You may also want to check out dspy.load, if you want to |
|
|
load an entire program, not just the state for an existing program. |
|
|
|
|
|
Args: |
|
|
path (str): Path to the saved state file, which should be a .json or a .pkl file |
|
|
""" |
|
|
path = Path(path) |
|
|
|
|
|
if hasattr(self.program, "load"): |
|
|
self.program.load(str(path)) |
|
|
|
|
|
return |
|
|
|
|
|
if path.suffix == ".json": |
|
|
with open(path) as f: |
|
|
state = ujson.loads(f.read()) |
|
|
elif path.suffix == ".pkl": |
|
|
with open(path, "rb") as f: |
|
|
state = cloudpickle.load(f) |
|
|
else: |
|
|
raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}") |
|
|
|
|
|
dependency_versions = get_dependency_versions() |
|
|
saved_dependency_versions = state["metadata"]["dependency_versions"] |
|
|
for key, saved_version in saved_dependency_versions.items(): |
|
|
if dependency_versions[key] != saved_version: |
|
|
logger.warning( |
|
|
f"There is a mismatch of {key} version between saved model and current environment. " |
|
|
f"You saved with `{key}=={saved_version}`, but now you have " |
|
|
f"`{key}=={dependency_versions[key]}`. This might cause errors or performance downgrade " |
|
|
"on the loaded model, please consider loading the model in the same environment as the " |
|
|
"saving environment." |
|
|
) |
|
|
self.load_state(state) |
|
|
self.sync_predict_inputs_to_program() |
|
|
|
|
|
|