Elron committed commit a62442d (verified) · 1 parent: c160aec

Upload folder using huggingface_hub

Files changed (7):
  1. inference.py +2 -0
  2. metric_utils.py +3 -1
  3. metrics.py +380 -3
  4. operators.py +103 -0
  5. splitters.py +4 -1
  6. text2sql_utils.py +3 -1
  7. version.py +1 -1
inference.py CHANGED
@@ -3268,6 +3268,7 @@ class CrossProviderInferenceEngine(
 
     provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
         "watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
+            "gpt-oss-120b": "openai/gpt-oss-120b",
             "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
             "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
             "granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
@@ -3290,6 +3291,7 @@ class CrossProviderInferenceEngine(
             "llama-3-2-3b-instruct": "meta-llama/llama-3-2-3b-instruct",
             "llama-3-2-90b-vision-instruct": "meta-llama/llama-3-2-90b-vision-instruct",
             "llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
+            "llama-4-maverick": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
             "llama-guard-3-11b-vision": "meta-llama/llama-guard-3-11b-vision",
             "mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7b-instruct-v01",
             "mistral-small-instruct": "mistralai/mistral-small-3-1-24b-instruct-2503",
metric_utils.py CHANGED
@@ -733,7 +733,9 @@ class InstanceScores(list):
         ).head()
         df["score_name"] = df["score"].apply(lambda x: x["instance"]["score_name"])
         df["all_scores"] = df["score"].apply(
-            lambda x: "\n".join(f"{k}: {v}" for k, v in x["instance"].items())
+            lambda x: "\n".join(
+                f"{k}: {v}" for k, v in x["instance"].items() if isoftype(v, float)
+            )
         )
         df["score"] = df["score"].apply(lambda x: x["instance"]["score"])
 
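The change restricts `all_scores` to numeric entries, so string fields such as `score_name` no longer pollute the summary column. A minimal sketch of the new formatting, with `isinstance` standing in for unitxt's `isoftype` helper:

```python
# Only float-valued entries survive the filter; "score_name" is skipped.
instance_scores = {"score": 0.5, "score_name": "accuracy", "accuracy": 0.5}
all_scores = "\n".join(
    f"{k}: {v}" for k, v in instance_scores.items() if isinstance(v, float)
)
print(all_scores)  # -> "score: 0.5\naccuracy: 0.5"
```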
metrics.py CHANGED
@@ -1,4 +1,5 @@
 import ast
+import asyncio
 import json
 import math
 import os
@@ -17,6 +18,7 @@ from typing import (
     Dict,
     Generator,
     Generic,
+    Iterable,
     List,
     Literal,
     Optional,
@@ -891,6 +893,375 @@ class MultiTurnToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]])
     }
 
 
+class ReflectionToolCallingMixin:
+    @staticmethod
+    def convert_tools_inventory(tools):
+        from llmevalkit.function_calling.pipeline.types import (
+            ToolSpec as LLMEvalKitToolSpec,
+        )
+
+        return [
+            LLMEvalKitToolSpec(
+                type="function",
+                function={**tool},
+            )
+            for tool in tools
+        ]
+
+    @staticmethod
+    def convert_tool_call(prediction: ToolCall):
+        from llmevalkit.function_calling.pipeline.types import (
+            ToolCall as LLMEvalKitToolCall,
+        )
+
+        return LLMEvalKitToolCall(
+            type="function",
+            function={
+                "name": prediction["name"],
+                "arguments": json.dumps(prediction["arguments"]),
+                "parsed_arguments": prediction["arguments"],
+            },
+        )
+
+
+class ReflectionToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
+    """Measures syntactic and semantic validity of tool calls.
+
+    The final output contains two main fields: "semantic" and "static" (i.e., syntactic).
+    Under "semantic" we define two types of metrics: general and function selection.
+
+    General metrics evaluate the overall quality and correctness of the tool call.
+    These metrics include:
+    1. General hallucination check: evaluates whether each parameter value in the function call is correct, directly supported by the provided conversation history, and adheres to the tool specifications.
+    2. Value format alignment: checks whether the format of the parameter values aligns with the expected formats defined in the tool specifications.
+
+    Function selection metrics evaluate the appropriateness of the selected function for the given context.
+    These metrics include:
+    1. Function selection appropriateness: assesses whether the chosen function is suitable for the task at hand.
+    2. Agentic constraints satisfaction: assesses whether the proposed tool call satisfies all agentic constraints required for execution.
+
+    Static metrics evaluate the syntactic validity of the tool call and include the following scores:
+    - non_existent_function: tool name not found.
+    - non_existent_parameter: argument name not in tool spec.
+    - incorrect_parameter_type: argument type mismatch.
+    - missing_required_parameter: required argument missing.
+    - allowed_values_violation: argument value outside allowed set.
+    - json_schema_violation: call violates JSON schema.
+    - empty_api_spec: no tool spec provided.
+    - invalid_api_spec: tool spec is invalid.
+    - invalid_tool_call: call is not a valid tool invocation.
+    - overall_valid: validity of the call (main score).
+    - score: alias of overall_valid.
+
+    Here is an example of an aggregated reflection output after calling reduce.
+    The range of each score is [0, 1] (higher indicates fewer errors).
+    {
+        "static_non_existent_function": 1.0,
+        "static_non_existent_parameter": 1.0,
+        "static_incorrect_parameter_type": 1.0,
+        "static_missing_required_parameter": 1.0,
+        "static_allowed_values_violation": 1.0,
+        "static_json_schema_violation": 1.0,
+        "static_empty_api_spec": 1.0,
+        "static_invalid_api_spec": 1.0,
+        "static_invalid_tool_call": 1.0,
+        "semantic_general_hallucination_check": 0.0,
+        "semantic_general_value_format_alignment": 0.0,
+        "semantic_avg_score_general": 1.0,
+        "semantic_function_selection_appropriateness": 0.0,
+        "semantic_agentic_constraints_satisfaction": 0.0,
+        "semantic_avg_score_function_selection": 1.0,
+        "overall_valid": 1.0
+    }
+
+    overall_valid is the final decision made by the reflection pipeline, indicating whether the tool call is valid.
+
+    Before aggregation, each metric also carries evidence, an explanation, a more fine-grained score, etc.
+
+    Reference: https://github.ibm.com/MLT/LLMEvalKit
+    """
+
+    main_score = "overall_valid"
+    reduction = MeanReduction()
+    prediction_type = ToolCall
+    _requirements_list = {
+        "llmevalkit": "Install with \"pip install 'git+ssh://git@github.ibm.com/MLT/LLMEvalKit.git'\".\nTo gain access, please reach out to the team."
+    }
+    runtime_pipeline: bool = True  # Whether to use the runtime pipeline or the longer evaluation pipeline with actionable recommendations
+
+    def prepare(self):
+        provider_to_default_reflector_model = {
+            "watsonx": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
+            "open-ai": "gpt-4o",
+            "rits": "openai/gpt-oss-120b",
+            "azure": "gpt-4o",
+            "mock": "mock",
+        }
+        provider = (
+            settings.default_provider if not settings.mock_inference_mode else "mock"
+        )
+        if provider not in provider_to_default_reflector_model:
+            raise ValueError(
+                f"Unsupported provider for ReflectionToolCallingMetric: {provider}. Supported providers are: {list(provider_to_default_reflector_model.keys())}"
+            )
+        self.requirements = self._get_missing_requirements_by_provider(provider)
+        super().prepare()
+        self.setup_pipeline(
+            reflector_model_name=provider_to_default_reflector_model.get(provider),
+            provider_name=provider,
+        )
+
+    def setup_pipeline(
+        self, reflector_model_name: str, provider_name: Optional[str] = None
+    ):
+        if provider_name:
+            llmeval_provider_name = self._get_llmeval_provider_name(provider_name)
+            requirements = self._get_missing_requirements_by_provider(provider_name)
+            self.check_missing_requirements(requirements)
+        else:
+            # Without a provider name the llmevalkit client cannot be resolved.
+            raise ValueError(
+                "provider_name is required to resolve the llmevalkit metrics client."
+            )
+
+        metrics_client = self._get_metrics_client(
+            llmeval_provider_name, reflector_model_name
+        )
+        self.reflection_pipeline = self._build_reflection_pipeline(metrics_client)
+        return self.reflection_pipeline
+
+    @staticmethod
+    def _get_llmeval_provider_name(provider_name: str) -> str:
+        mapping = {
+            "watsonx": "watsonx.output_val",
+            "open-ai": "openai.async.output_val",
+            "rits": "litellm.rits.output_val",
+            "azure": "azure_openai.async.output_val",
+            "mock": "mock.output_val",
+        }
+        llmeval_provider_name = mapping.get(provider_name)
+        if llmeval_provider_name is None:
+            raise ValueError(f"Unsupported provider by llmevalkit: {provider_name}")
+        return llmeval_provider_name
+
+    @staticmethod
+    def _get_missing_requirements_by_provider(provider_name: str):
+        provider_libs = {
+            "watsonx": "ibm_watsonx_ai",
+            "open-ai": "openai",
+            "rits": "litellm",
+            "azure": "openai",
+        }
+        required_lib = provider_libs.get(provider_name)
+        return [required_lib] if required_lib else []
+
+    @staticmethod
+    def _get_metrics_client(llmeval_provider_name: str, reflector_model_name: str):
+        from llmevalkit.llm import get_llm
+
+        metrics_client_cls = get_llm(llmeval_provider_name)
+        return metrics_client_cls(model_name=reflector_model_name)
+
+    def _build_reflection_pipeline(self, metrics_client):
+        from llmevalkit.function_calling.consts import (
+            METRIC_AGENTIC_CONSTRAINTS_SATISFACTION,
+            METRIC_FUNCTION_SELECTION_APPROPRIATENESS,
+            METRIC_GENERAL_HALLUCINATION_CHECK,
+            METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT,
+        )
+        from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline
+
+        return ReflectionPipeline(
+            metrics_client=metrics_client,
+            general_metrics=[
+                METRIC_GENERAL_HALLUCINATION_CHECK,
+                METRIC_GENERAL_VALUE_FORMAT_ALIGNMENT,
+            ],
+            function_metrics=[
+                METRIC_FUNCTION_SELECTION_APPROPRIATENESS,
+                METRIC_AGENTIC_CONSTRAINTS_SATISFACTION,
+            ],
+            parameter_metrics=[],
+            runtime_pipeline=self.runtime_pipeline,
+        )
+
+    async def map(
+        self,
+        prediction: ToolCall,
+        references: None,
+        task_data: Dict[str, Any],
+    ):
+        from llmevalkit.function_calling.pipeline.types import PipelineResult
+
+        # Convert unitxt dialog to LLMEvalKit format
+        if "dialog" in task_data:
+            conversation_history = [dict(turn) for turn in task_data["dialog"]]
+        elif "query" in task_data:
+            conversation_history = [{"role": "user", "content": task_data["query"]}]
+        else:
+            raise ValueError("task_data must contain either 'dialog' or 'query' field.")
+
+        # Convert unitxt tool inventory to LLMEvalKit format
+        tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(
+            task_data.get("tools", [])
+        )
+
+        # Convert unitxt tool call to LLMEvalKit format
+        tool_call_converted = ReflectionToolCallingMixin.convert_tool_call(prediction)
+
+        # Run reflection (syntactic + semantic)
+        result: PipelineResult = await self.reflection_pipeline.run_async(
+            conversation=conversation_history,
+            inventory=tools_inventory,
+            call=tool_call_converted,
+            retries=3,
+            continue_on_static=True,
+        )
+        result_dict = result.model_dump()
+        result_dict["overall_valid"] = float(result_dict["overall_valid"])
+        return result_dict
+
+    def map_stream(
+        self,
+        items: Iterable[Tuple[ToolCall, None, Dict[str, Any]]],
+        *,
+        max_concurrency: int = 8,
+    ) -> List[Dict[str, Any]]:
+        """Run self.map in parallel over an iterable and return results in order."""
+
+        async def process_all():
+            items_iter = iter(enumerate(items))
+            results = []
+            pending = set()
+            while True:
+                # Keep up to max_concurrency map() tasks in flight.
+                while len(pending) < max_concurrency:
+                    try:
+                        idx, (pred, refs, data) = next(items_iter)
+                        if isinstance(pred, list):
+                            for p in pred:
+                                task = asyncio.create_task(self.map(p, refs, data))
+                                task.idx = idx
+                                pending.add(task)
+                        else:
+                            task = asyncio.create_task(self.map(pred, refs, data))
+                            task.idx = idx
+                            pending.add(task)
+                    except StopIteration:
+                        break
+                if not pending:
+                    break
+                done, pending = await asyncio.wait(
+                    pending, return_when=asyncio.FIRST_COMPLETED
+                )
+                for task in done:
+                    results.append((task.idx, await task))
+            # Sort on the input index only; the result dicts are not orderable.
+            results.sort(key=lambda pair: pair[0])
+            return [r for _, r in results]
+
+        return asyncio.run(process_all())
+
+    def reduce_one(self, intermediate: Dict[str, Any]) -> Dict[str, float]:
+        return intermediate
+
+    def reduce(self, intermediates: List[Dict[str, Any]]) -> Dict[str, float]:
+        flat_instances = []
+        for instance in intermediates:
+            flat_instance_dict = {}
+            for metric, metric_type_dict in (
+                instance.get("static", {}).get("metrics", {}).items()
+            ):
+                flat_instance_dict[f"static_{metric}"] = float(
+                    metric_type_dict["valid"]
+                )
+
+            for metric_type, metric_type_dict in instance.get("semantic", {}).items():
+                if metric_type_dict is not None:
+                    for metric, metric_dict in metric_type_dict.get(
+                        "metrics", {}
+                    ).items():
+                        flat_instance_dict[f"semantic_{metric}"] = 1 - float(
+                            metric_dict["is_issue"]
+                        )
+                    flat_instance_dict[f"semantic_avg_score_{metric_type}"] = float(
+                        metric_type_dict.get("avg_score")
+                    )
+
+            flat_instance_dict["overall_valid"] = float(instance["overall_valid"])
+            flat_instances.append(flat_instance_dict)
+
+        return self.reduction.reduce(flat_instances)
+
+
+class ReflectionToolCallingMetricSyntactic(
+    ReductionInstanceMetric[str, Dict[str, float]]
+):
+    """Measures syntactic and schema validity of tool calls.
+
+    Range: [0, 1] (higher indicates fewer errors).
+    Returns 1.0 for each metric if the tool call is valid, 0.0 otherwise.
+    overall_valid equals 1.0 if all metrics are valid, 0.0 otherwise.
+    The global score is the percentage of valid instances across the dataset.
+
+    Scores:
+    - non_existent_function: tool name not found.
+    - non_existent_parameter: argument name not in tool spec.
+    - incorrect_parameter_type: argument type mismatch.
+    - missing_required_parameter: required argument missing.
+    - allowed_values_violation: argument value outside allowed set.
+    - json_schema_violation: call violates JSON schema.
+    - empty_api_spec: no tool spec provided.
+    - invalid_api_spec: tool spec is invalid.
+    - invalid_tool_call: call is not a valid tool invocation.
+    - overall_valid: validity of the call (main score).
+    - score: alias of overall_valid.
+
+    Reference: https://github.ibm.com/MLT/LLMEvalKit
+    """
+
+    main_score = "overall_valid"
+    reduction = MeanReduction()
+    prediction_type = ToolCall
+    _requirements_list = {
+        "llmevalkit": "Install with \"pip install 'git+ssh://git@github.ibm.com/MLT/LLMEvalKit.git'\".\nTo gain access, please reach out to the team."
+    }
+
+    def map(
+        self,
+        prediction: ToolCall,
+        references: None,
+        task_data: Dict[str, Any],
+    ) -> Dict[str, float]:
+        from llmevalkit.function_calling.pipeline.pipeline import ReflectionPipeline
+
+        # Convert unitxt tool inventory to LLMEvalKit format
+        tools_inventory = ReflectionToolCallingMixin.convert_tools_inventory(
+            task_data.get("tools", [])
+        )
+
+        # Convert unitxt tool call to LLMEvalKit format
+        tool_call = ReflectionToolCallingMixin.convert_tool_call(prediction)
+
+        # Run static validation
+        static_result = ReflectionPipeline.static_only(tools_inventory, tool_call)
+
+        result_dict = static_result.model_dump()
+        result_dict["overall_valid"] = float(result_dict.pop("final_decision"))
+        result_dict["metrics"]["json_schema_violation"] = result_dict["metrics"].pop(
+            "json_schema_validation"
+        )
+        return result_dict
+
+    def reduce_one(self, intermediate: Dict[str, float]) -> Dict[str, float]:
+        return intermediate
+
+    def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, float]:
+        flat_instances = []
+        for instance in intermediates:
+            flat_instance_dict = {}
+            for metric, metric_dict in instance.get("metrics", {}).items():
+                flat_instance_dict[metric] = float(metric_dict["valid"])
+            flat_instance_dict["overall_valid"] = float(instance["overall_valid"])
+            flat_instances.append(flat_instance_dict)
+
+        return self.reduction.reduce(flat_instances)
+
+
 class MetricWithConfidenceInterval(Metric):
     # The number of resamples used to estimate the confidence intervals of this metric.
     # Use None to disable confidence interval computation.
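The `map_stream` method added above is a bounded-concurrency fan-out: it keeps at most `max_concurrency` `map` coroutines in flight and restores input order before returning. A self-contained sketch of the same pattern, with the index tracked in a side dict rather than as a task attribute (all names are illustrative):

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable, List

async def bounded_gather(
    items: Iterable[Any],
    fn: Callable[[Any], Awaitable[Any]],
    max_concurrency: int = 8,
) -> List[Any]:
    items_iter = iter(enumerate(items))
    pending, results, task_index = set(), [], {}
    while True:
        # Top up the in-flight set until the concurrency cap is reached.
        while len(pending) < max_concurrency:
            try:
                idx, item = next(items_iter)
            except StopIteration:
                break
            task = asyncio.create_task(fn(item))
            task_index[task] = idx
            pending.add(task)
        if not pending:
            break
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            results.append((task_index.pop(task), await task))
    # Restore input order by sorting on the index alone.
    results.sort(key=lambda pair: pair[0])
    return [r for _, r in results]

async def _square(x: int) -> int:
    await asyncio.sleep(0.01)
    return x * x

print(asyncio.run(bounded_gather(range(5), _square, max_concurrency=2)))
# -> [0, 1, 4, 9, 16]
```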
 
@@ -4889,12 +5260,11 @@ class FaithfulnessHHEM(BulkInstanceMetric):
     # single_reference_per_prediction = True
     max_context_words = 4096
     reduction_map = {"mean": [main_score]}
+    model = None
 
     _requirements_list: List[str] = ["transformers", "torch"]
 
-    @retry_connection_with_exponential_backoff(backoff_factor=2)
-    def prepare(self):
-        super().prepare()
+    def load_model(self):
         import torch
 
         if torch.cuda.is_available():
@@ -4912,6 +5282,11 @@ class FaithfulnessHHEM(BulkInstanceMetric):
             model_path, trust_remote_code=True
         ).to(device)
 
+    @retry_connection_with_exponential_backoff(backoff_factor=2)
+    def prepare(self):
+        super().prepare()
+        # load_model() was moved from prepare() to compute() because the model is gated on HF
+
     def compute(
         self,
         references: List[List[Any]],
@@ -4920,6 +5295,8 @@ class FaithfulnessHHEM(BulkInstanceMetric):
     ) -> List[Dict[str, Any]]:
         from tqdm import tqdm
 
+        if self.model is None:
+            self.load_model()
         # treat the references as the contexts and the predictions as answers
         # concat references
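This `FaithfulnessHHEM` change defers model loading from `prepare()` to the first `compute()` call, since the HHEM checkpoint is gated on the Hugging Face Hub and should only be downloaded when the metric is actually used. A distilled sketch of the lazy-initialization pattern (the class and stand-in values here are illustrative, not the metric's real API):

```python
class LazyModelMetric:
    model = None  # class-level sentinel: nothing is loaded at setup time

    def load_model(self):
        # The expensive (possibly gated) download happens here, on demand.
        self.model = "loaded"  # stand-in for AutoModelForSequenceClassification(...)

    def compute(self, predictions):
        if self.model is None:  # only the first call pays the loading cost
            self.load_model()
        return [0.0 for _ in predictions]  # placeholder scores

metric = LazyModelMetric()
metric.compute(["p1", "p2"])  # triggers load_model() exactly once
metric.compute(["p3"])        # reuses the already-loaded model
```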
operators.py CHANGED
@@ -40,6 +40,7 @@ General Operators List:
 """
 
 import operator
+import re
 import uuid
 import warnings
 import zipfile
@@ -2611,3 +2612,105 @@ class ReadFile(FieldOperator):
         # Read from local file
         with open(value, encoding=self.encoding) as f:
             return f.read()
+
+
+class FixJsonSchemaOfParameterTypes(InstanceOperator):
+    main_field: str
+
+    def prepare(self):
+        self.simple_mapping = {
+            "": "object",
+            "any": "object",
+            "Any": "object",
+            "Array": "array",
+            "ArrayList": "array",
+            "Bigint": "integer",
+            "bool": "boolean",
+            "Boolean": "boolean",
+            "byte": "integer",
+            "char": "string",
+            "dict": "object",
+            "Dict": "object",
+            "double": "number",
+            "float": "number",
+            "HashMap": "object",
+            "Hashtable": "object",
+            "int": "integer",
+            "list": "array",
+            "List": "array",
+            "long": "integer",
+            "Queue": "array",
+            "short": "integer",
+            "Stack": "array",
+            "tuple": "array",
+            "Set": "array",
+            "set": "array",
+            "str": "string",
+            "String": "string",
+        }
+
+    def dict_type_of(self, type_str: str) -> dict:
+        return {"type": type_str}
+
+    def recursive_trace_for_type_fields(self, containing_element):
+        if isinstance(containing_element, dict):
+            keys = list(containing_element.keys())
+            for key in keys:
+                if key == "type" and isinstance(containing_element["type"], str):
+                    jsonschema_dict = self.type_str_to_jsonschema_dict(
+                        containing_element["type"]
+                    )
+                    containing_element.pop("type")
+                    containing_element.update(jsonschema_dict)
+                else:
+                    self.recursive_trace_for_type_fields(containing_element[key])
+        elif isinstance(containing_element, list):
+            for list_element in containing_element:
+                self.recursive_trace_for_type_fields(list_element)
+
+    def type_str_to_jsonschema_dict(self, type_str: str) -> dict:
+        if type_str in self.simple_mapping:
+            return self.dict_type_of(self.simple_mapping[type_str])
+        m = re.match(r"^(List|Tuple)\[(.*?)\]$", type_str)
+        if m:
+            basic_type = self.dict_type_of("array")
+            basic_type["items"] = self.type_str_to_jsonschema_dict(
+                m.group(2) if m.group(1) == "List" else m.group(2).split(",")[0].strip()
+            )
+            return basic_type
+
+        m = re.match(r"^(Union)\[(.*?)\]$", type_str)
+        if m:
+            args = m.group(2).split(",")
+            for i in range(len(args)):
+                args[i] = args[i].strip()
+            return {"anyOf": [self.type_str_to_jsonschema_dict(arg) for arg in args]}
+        if re.match(r"^(Callable)\[(.*?)\]$", type_str):
+            return self.dict_type_of("object")
+        if "," in type_str:
+            sub_types = type_str.split(",")
+            for i in range(len(sub_types)):
+                sub_types[i] = sub_types[i].strip()
+            assert len(sub_types) in [
+                2,
+                3,
+            ], f"num of subtypes should be 2 or 3, got {type_str}"
+            basic_type = self.type_str_to_jsonschema_dict(sub_types[0])
+            for sub_type in sub_types[1:]:
+                if sub_type.lower().startswith("default"):
+                    basic_type["default"] = re.split(r"[= ]", sub_type, maxsplit=1)[1]
+            for sub_type in sub_types[1:]:
+                if sub_type.lower().startswith("optional"):
+                    return {"anyOf": [basic_type, self.dict_type_of("null")]}
+            return basic_type
+
+        return self.dict_type_of(type_str)  # otherwise - return what arrived
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        assert (
+            self.main_field in instance
+        ), f"field '{self.main_field}' must reside in instance in order to verify its jsonschema correctness. got {instance}"
+        self.recursive_trace_for_type_fields(instance[self.main_field])
+        return instance
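`FixJsonSchemaOfParameterTypes` rewrites free-form `type` strings found in tool specifications into proper JSON-schema fragments. A distilled, runnable re-implementation of three of its branches (simple mapping, `List[...]`, `Union[...]`); the full operator above also handles `Tuple`, `Callable`, defaults, and `optional` markers:

```python
import re

SIMPLE = {"int": "integer", "str": "string", "list": "array", "dict": "object"}

def to_jsonschema(type_str: str) -> dict:
    if type_str in SIMPLE:
        return {"type": SIMPLE[type_str]}
    m = re.match(r"^List\[(.*?)\]$", type_str)
    if m:  # e.g. "List[int]" -> array of integers
        return {"type": "array", "items": to_jsonschema(m.group(1))}
    m = re.match(r"^Union\[(.*?)\]$", type_str)
    if m:  # e.g. "Union[int, str]" -> anyOf
        return {"anyOf": [to_jsonschema(a.strip()) for a in m.group(1).split(",")]}
    return {"type": type_str}  # pass through anything unrecognized

assert to_jsonschema("List[int]") == {"type": "array", "items": {"type": "integer"}}
assert to_jsonschema("Union[int, str]") == {
    "anyOf": [{"type": "integer"}, {"type": "string"}]
}
```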
splitters.py CHANGED
@@ -312,11 +312,14 @@ class DiverseLabelsSampler(Sampler):
         sample_size: int,
         instances_pool: List[Dict[str, object]],
         instance: Optional[Dict[str, object]],
+        sampling_seed: Optional[int] = None,
     ) -> List[Dict[str, object]]:
         if self.labels_cache is None:
             self.labels_cache = self.divide_by_repr(instances_pool)
         all_labels = list(self.labels_cache.keys())
-        random_generator = get_random_generator_based_on_instance(instance)
+        random_generator = get_random_generator_based_on_instance(
+            instance, local_seed=sampling_seed
+        )
         random_generator.shuffle(all_labels)
         from collections import Counter
 
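The new `sampling_seed` parameter lets callers pin the demo-selection shuffle to a fixed seed instead of deriving it from the instance. A hedged sketch of the intended behavior (`make_generator` stands in for `get_random_generator_based_on_instance`, whose exact seed derivation is not shown in this diff):

```python
import random
from typing import Any, Optional

def make_generator(instance: Any, local_seed: Optional[int] = None) -> random.Random:
    # Stand-in: seed from the instance unless an explicit local seed is given.
    seed = local_seed if local_seed is not None else str(instance)
    return random.Random(seed)

labels = ["pos", "neg", "neutral"]
a, b = labels.copy(), labels.copy()
make_generator({"id": 1}, local_seed=42).shuffle(a)
make_generator({"id": 2}, local_seed=42).shuffle(b)
assert a == b  # same sampling_seed -> same shuffle, regardless of instance
```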
text2sql_utils.py CHANGED
@@ -728,7 +728,7 @@ def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
                 return False
         return True
     except Exception as e:
-        logger.debug(f"Errpr parsing SQL query for comparison: {e}")
+        logger.debug(f"Error parsing SQL query for comparison: {e}")
         return False
 
 
@@ -863,6 +863,8 @@ def compare_dfs_ignore_colnames_subset(
     if df1.empty or df2.empty or len(df1) != len(df2):
         return False
 
+    df1.columns = range(df1.shape[1])
+    df2.columns = range(df2.shape[1])
     subset_df, superset_df = (df1, df2) if df1.shape[1] <= df2.shape[1] else (df2, df1)
 
     if ignore_row_order:
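The two added lines normalize both frames to positional column labels before the subset comparison, so frames with identical data but different column names compare equal. A minimal illustration of why the rename matters:

```python
import pandas as pd

df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
df2 = pd.DataFrame({"x": [1, 2], "y": [3, 4]})
assert not df1.equals(df2)  # column names differ

# Same normalization as in compare_dfs_ignore_colnames_subset:
df1.columns = range(df1.shape[1])
df2.columns = range(df2.shape[1])
assert df1.equals(df2)  # now only the values matter
```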
version.py CHANGED
@@ -1 +1 @@
-version = "1.26.6"
+version = "1.26.7"